rossellison committed on
Commit
8a860df
1 Parent(s): 181927b

Upload 159 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +87 -0
  2. requirements.txt +17 -0
  3. stylegan3-fun/.github/FUNDING.yml +3 -0
  4. stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md +35 -0
  5. stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
  6. stylegan3-fun/.gitignore +173 -0
  7. stylegan3-fun/__pycache__/legacy.cpython-311.pyc +0 -0
  8. stylegan3-fun/__pycache__/legacy.cpython-38.pyc +0 -0
  9. stylegan3-fun/__pycache__/legacy.cpython-39.pyc +0 -0
  10. stylegan3-fun/dnnlib/__init__.py +9 -0
  11. stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc +0 -0
  12. stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
  13. stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  14. stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc +0 -0
  15. stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc +0 -0
  16. stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  17. stylegan3-fun/dnnlib/util.py +491 -0
  18. stylegan3-fun/gui_utils/__init__.py +9 -0
  19. stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc +0 -0
  20. stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc +0 -0
  21. stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  22. stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc +0 -0
  23. stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc +0 -0
  24. stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc +0 -0
  25. stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc +0 -0
  26. stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc +0 -0
  27. stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc +0 -0
  28. stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc +0 -0
  29. stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc +0 -0
  30. stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc +0 -0
  31. stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc +0 -0
  32. stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc +0 -0
  33. stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc +0 -0
  34. stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc +0 -0
  35. stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc +0 -0
  36. stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc +0 -0
  37. stylegan3-fun/gui_utils/gl_utils.py +374 -0
  38. stylegan3-fun/gui_utils/glfw_window.py +229 -0
  39. stylegan3-fun/gui_utils/imgui_utils.py +169 -0
  40. stylegan3-fun/gui_utils/imgui_window.py +103 -0
  41. stylegan3-fun/gui_utils/text_utils.py +123 -0
  42. stylegan3-fun/legacy.py +346 -0
  43. stylegan3-fun/metrics/__init__.py +9 -0
  44. stylegan3-fun/metrics/equivariance.py +267 -0
  45. stylegan3-fun/metrics/frechet_inception_distance.py +41 -0
  46. stylegan3-fun/metrics/inception_score.py +38 -0
  47. stylegan3-fun/metrics/kernel_inception_distance.py +46 -0
  48. stylegan3-fun/metrics/metric_main.py +153 -0
  49. stylegan3-fun/metrics/metric_utils.py +279 -0
  50. stylegan3-fun/metrics/perceptual_path_length.py +125 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import List, Optional, Tuple, Union
5
+ import random
6
+
7
+ sys.path.append('stylegan3-fun') # change this to the path where dnnlib is located
8
+
9
+ import numpy as np
10
+ import PIL.Image
11
+ import torch
12
+ import streamlit as st
13
+ import dnnlib
14
+ import legacy
15
+
16
+
17
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    # Already a list: pass through unchanged.
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            # 'a-b' is inclusive of both endpoints, hence the +1.
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges
32
+
33
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 homogeneous 2D transform: rotation by `angle` degrees plus translation."""
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)
    matrix = np.eye(3)
    matrix[0][0] = cos_t
    matrix[0][1] = sin_t
    matrix[0][2] = translate[0]
    matrix[1][0] = -sin_t
    matrix[1][1] = cos_t
    matrix[1][2] = translate[1]
    return matrix
44
+
45
def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
    """Load a pickled generator and synthesize one image.

    Args:
        network_pkl: Path of the pickled network; opened as a local file.
        seed: Seed for the latent vector (np.random.RandomState).
        truncation_psi: Truncation parameter forwarded to the generator.
        noise_mode: Noise mode string forwarded to the generator (e.g. 'const').
        translate: (x, y) translation applied via the synthesis input transform
            (only when the network exposes G.synthesis.input).
        rotate: Rotation in degrees, applied the same way.
        class_idx: One-hot class index; required when G.c_dim != 0.

    Returns:
        A PIL.Image in RGB mode.

    Raises:
        Exception: if the network is conditional and class_idx is None.
    """
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with open(network_pkl, 'rb') as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # Labels: one-hot class vector for conditional networks, zero-width otherwise.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise Exception('Must specify class label when using a conditional network')
        label[:, class_idx] = 1

    # Deterministic latent derived from the seed.
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

    # NOTE(review): presumably a StyleGAN3-style network whose input layer holds a
    # user-settable transform; the inverse is applied so the requested
    # translation/rotation moves the output image — confirm against the network class.
    if hasattr(G.synthesis, 'input'):
        m = make_transform(translate, rotate)
        m = np.linalg.inv(m)
        G.synthesis.input.transform.copy_(torch.from_numpy(m))

    img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
    # Rescale to uint8 [0, 255] (assumes generator output is roughly [-1, 1] — TODO confirm), NCHW -> NHWC.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    return img
69
+
70
def main():
    """Streamlit entry point: a single button that samples a random seed and shows the result."""
    st.title('Kpop Face Generator')

    st.write('Press the button below to generate a new image:')
    if st.button('Generate'):
        # Fixed generation settings; only the seed varies per click.
        image = generate_image(
            'kpopGG.pkl',                # network_pkl
            random.randint(0, 99999),    # seed
            0.45,                        # truncation_psi
            'const',                     # noise_mode
            (0.0, 0.0),                  # translate
            0.0,                         # rotate
            None,                        # class_idx (unconditional network)
        )
        st.image(image)

if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ click==8.1.3
2
+ ffmpeg_python==0.2.0
3
+ glfw==2.2.0
4
+ imgui==1.3.0
5
+ matplotlib==3.4.2
6
+ numpy==1.21.6
7
+ Pillow==10.0.0
9
+ PyOpenGL==3.1.7
11
+ Requests==2.31.0
12
+ scipy==1.7.3
13
+ setuptools==60.2.0
15
+ streamlit==1.25.0
16
+ torch==2.0.1
17
+ training==0.0.1
stylegan3-fun/.github/FUNDING.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # These are supported funding model platforms
2
+
3
+ github: PDillis # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+ Steps to reproduce the behavior:
15
+ 1. In '...' directory, run command '...'
16
+ 2. See error (copy&paste full log, including exceptions and **stacktraces**).
17
+
18
+ Please copy&paste text instead of screenshots for better searchability.
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Screenshots**
24
+ If applicable, add screenshots to help explain your problem.
25
+
26
+ **Desktop (please complete the following information):**
27
+ - OS: [e.g. Linux Ubuntu 20.04, Windows 10]
28
+ - PyTorch version (e.g., pytorch 1.9.0)
29
+ - CUDA toolkit version (e.g., CUDA 11.4)
30
+ - NVIDIA driver version
31
+ - GPU [e.g., Titan V, RTX 3090]
32
+ - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3)
33
+
34
+ **Additional context**
35
+ Add any other context about the problem here.
stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12
+
13
+ **Describe the solution you'd like**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Describe alternatives you've considered**
17
+ A clear and concise description of any alternative solutions or features you've considered.
18
+
19
+ **Additional context**
20
+ Add any other context or screenshots about the feature request here.
stylegan3-fun/.gitignore ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ */**/__pycache__/
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ .idea/
162
+
163
+ # Conda temp
164
+ .condatmp/
165
+
166
+ # SGAN specific folders
167
+ datasets/
168
+ dlatents/
169
+ out/
170
+ training-runs/
171
+ pretrained/
172
+ video/
173
+ _screenshots/
stylegan3-fun/__pycache__/legacy.cpython-311.pyc ADDED
Binary file (29 kB). View file
 
stylegan3-fun/__pycache__/legacy.cpython-38.pyc ADDED
Binary file (15.4 kB). View file
 
stylegan3-fun/__pycache__/legacy.cpython-39.pyc ADDED
Binary file (15.3 kB). View file
 
stylegan3-fun/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (285 Bytes). View file
 
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (245 Bytes). View file
 
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (245 Bytes). View file
 
stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc ADDED
Binary file (26.1 kB). View file
 
stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc ADDED
Binary file (14.1 kB). View file
 
stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc ADDED
Binary file (14.1 kB). View file
 
stylegan3-fun/dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
class EasyDict(dict):
    """Dict subclass whose entries are also reachable via attribute syntax (d.key == d['key'])."""

    def __getattr__(self, name: str) -> Any:
        # Missing keys must surface as AttributeError (not KeyError) so that
        # hasattr() and friends behave correctly.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute writes go straight into the underlying mapping.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
54
+
55
+
56
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep references to the real streams so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Install this object as both streams; write() mirrors everything to
        # the original stdout (and the file, if any).
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
113
+
114
+
115
# Cache directories
# ------------------------------------------------------------------------------------------

# Module-level override set via set_cache_dir(); takes precedence over all
# environment-based locations in make_cache_dir_path().
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Set the module-wide cache root used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Join *paths under the cache root, resolved in priority order:
    set_cache_dir() value, $DNNLIB_CACHE_DIR, $HOME/.cache/dnnlib,
    %USERPROFILE%/.cache/dnnlib, then the system temp directory."""
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    total = int(np.rint(seconds))

    minute, hour, day = 60, 60 * 60, 24 * 60 * 60
    if total < minute:
        return "{0}s".format(total)
    if total < hour:
        return "{0}m {1:02}s".format(total // minute, total % minute)
    if total < day:
        return "{0}h {1:02}m {2:02}s".format(total // hour, (total // minute) % 60, total % minute)
    # A day or more: minutes are the finest unit shown.
    return "{0}d {1:02}h {2:02}m".format(total // day, (total // hour) % 24, (total // minute) % 60)
151
+
152
+
153
def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    total = int(np.rint(seconds))

    # Check largest units first; exactly two units are shown beyond a minute.
    if total >= 24 * 60 * 60:
        return "{0}d {1:02}h".format(total // (24 * 60 * 60), (total // (60 * 60)) % 24)
    if total >= 60 * 60:
        return "{0}h {1:02}m".format(total // (60 * 60), (total // 60) % 60)
    if total >= 60:
        return "{0}m {1:02}s".format(total // 60, total % 60)
    return "{0}s".format(total)
165
+
166
+
167
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepts the same answers as distutils.util.strtobool did (distutils was
    deprecated by PEP 632 and removed in Python 3.12, so the parsing is
    inlined here): y/yes/t/true/on/1 => True, n/no/f/false/off/0 => False.
    Any other input re-prompts.
    """
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
        # Invalid answer: loop and ask again (mirrors the original ValueError loop).
175
+
176
+
177
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    # Empty tuple yields the multiplicative identity, 1.
    product = 1
    for element in t:
        product = product * element
    return product
185
+
186
+
187
# Type names -> ctypes type of the same width; used by get_dtype_and_ctype().
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Resolve the input to a plain type-name string.
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]
    # Both representations must occupy the same number of bytes.
    assert dtype.itemsize == ctypes.sizeof(ctype)
    return dtype, ctype
222
+
223
+
224
def is_pickleable(obj: Any) -> bool:
    """Return True if obj can be serialized with pickle, False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Narrowed from a bare `except:` so that KeyboardInterrupt and
        # SystemExit still propagate instead of being reported as "not pickleable".
        return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module
    # prefix first, e.g. "a.b.c" -> [("a.b.c", ""), ("a.b", "c"), ("a", "b.c")]
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            # Re-raise only when the module exists but fails on import;
            # a plain "No module named ..." is expected for bad prefixes.
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
275
+
276
+
277
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    # Empty name refers to the module itself.
    if not obj_name:
        return module
    target = module
    for attr in obj_name.split("."):
        target = getattr(target, attr)
    return target
285
+
286
+
287
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    # func_name is keyword-only; all positional args are forwarded to the target.
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    # Classes are callables, so this is just call_func_by_name under another name.
    return call_func_by_name(*args, func_name=class_name, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    if not callable(obj):
        return False
    # A top-level function's name appears directly in its module's namespace.
    return obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        # Scripts run directly report '__main__'; use the file's stem instead.
        module_name = os.path.splitext(os.path.basename(sys.modules[module_name].__file__))[0]
    return module_name + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths.

    Args:
        dir_path: Root directory to walk; must exist.
        ignores: fnmatch-style patterns. Matching directories are pruned from
            the walk; matching files are dropped from the results.
        add_base_to_relative: If True, prefix relative paths with the base name
            of dir_path.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place (os.walk with topdown=True
            # re-reads it to decide which subdirectories to descend into)
            for d in dirs_to_remove:
                dirs.remove(d)

            # each pattern filters the already-filtered file list, so all
            # patterns apply cumulatively
            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result
360
+
361
+
362
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for src, dst in files:
        target_dir_name = os.path.dirname(dst)

        # exist_ok=True avoids the check-then-create race of the original
        # `if not os.path.exists(...)` pattern; the empty-dirname guard lets a
        # destination in the current directory work (os.makedirs('') raises).
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src, dst)
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    # Cheap rejection: must be a string containing a scheme separator.
    if not isinstance(obj, str) or not "://" in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        # Require a scheme and a dotted hostname, both for the URL itself and
        # for its site root (urljoin with "/"), before accepting it.
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True
394
+
395
+
396
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Local paths and file:// URLs are opened directly. Remote URLs are fetched
    with retries and, when `cache` is True, stored under `cache_dir` keyed by
    the URL's MD5. With `return_filename` the on-disk path is returned instead
    of a file object (requires `cache`).
    """
    assert num_attempts >= 1
    # A filename can only be returned if a cached copy exists on disk.
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the actual payload.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Follow the confirmation link on the next attempt.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                # Retry on any other failure until attempts are exhausted.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        # Write to a unique temp file first so concurrent downloads can't
        # observe a partially-written cache entry.
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
stylegan3-fun/gui_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (181 Bytes). View file
 
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (181 Bytes). View file
 
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc ADDED
Binary file (12.1 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc ADDED
Binary file (7.8 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc ADDED
Binary file (7.78 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc ADDED
Binary file (9.73 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc ADDED
Binary file (5.1 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc ADDED
Binary file (5.19 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc ADDED
Binary file (7.11 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc ADDED
Binary file (4 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc ADDED
Binary file (3.99 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc ADDED
Binary file (9.87 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc ADDED
Binary file (5.19 kB). View file
 
stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc ADDED
Binary file (5.15 kB). View file
 
stylegan3-fun/gui_utils/gl_utils.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import functools
11
+ import contextlib
12
+ import numpy as np
13
+ import OpenGL.GL as gl
14
+ import OpenGL.GL.ARB.texture_float
15
+ import dnnlib
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
def init_egl():
    """Initialize a headless OpenGL context via EGL (for off-screen rendering without a display)."""
    assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
    import OpenGL.EGL as egl
    import ctypes

    # Initialize EGL.
    display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
    assert display != egl.EGL_NO_DISPLAY
    major = ctypes.c_int32()
    minor = ctypes.c_int32()
    ok = egl.eglInitialize(display, major, minor)
    assert ok
    assert major.value * 10 + minor.value >= 14  # Require EGL version >= 1.4.

    # Choose config.
    config_attribs = [
        egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
        egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
        egl.EGL_NONE
    ]
    configs = (ctypes.c_int32 * 1)()
    num_configs = ctypes.c_int32()
    ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
    assert ok
    assert num_configs.value == 1
    config = configs[0]

    # Create dummy pbuffer surface.
    # A 1x1 surface is enough; actual rendering goes into FBOs.
    surface_attribs = [
        egl.EGL_WIDTH, 1,
        egl.EGL_HEIGHT, 1,
        egl.EGL_NONE
    ]
    surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
    assert surface != egl.EGL_NO_SURFACE

    # Setup GL context.
    ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
    assert ok
    context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
    assert context != egl.EGL_NO_CONTEXT
    ok = egl.eglMakeCurrent(display, surface, surface, context)
    assert ok
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
# Mapping from (numpy dtype name, channel count) to the GL pixel type,
# client-side format, and internal storage format used for textures/renderbuffers.
_texture_formats = {
    ('uint8',   1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE,       internalformat=gl.GL_LUMINANCE8),
    ('uint8',   2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
    ('uint8',   3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB,             internalformat=gl.GL_RGB8),
    ('uint8',   4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA,            internalformat=gl.GL_RGBA8),
    ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_LUMINANCE,       internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
    ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
    ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_RGB,             internalformat=gl.GL_RGB32F),
    ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT,         format=gl.GL_RGBA,            internalformat=gl.GL_RGBA32F),
}
75
+
76
def get_texture_format(dtype, channels):
    """Look up the GL type/format/internalformat triple for a (dtype, channel count) pair."""
    key = (np.dtype(dtype).name, int(channels))
    return _texture_formats[key]
78
+
79
+ #----------------------------------------------------------------------------
80
+
81
def prepare_texture_data(image):
    """Normalize an image array for GL upload: HWC layout and a GL-supported dtype."""
    arr = np.asarray(image)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]  # Promote grayscale to a single trailing channel axis.
    if arr.dtype.name == 'float64':
        arr = arr.astype('float32')  # The GL format table only covers float32, not float64.
    return arr
88
+
89
+ #----------------------------------------------------------------------------
90
+
91
def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
    """Draw an image directly into the current GL framebuffer via glDrawPixels.

    pos/zoom/align may be scalars or 2-vectors (broadcast to [x, y]).
    The drawing position is offset by -size*align, and rounded to whole
    pixels when rint is True.
    """
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
    align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
    image = prepare_texture_data(image)
    height, width, channels = image.shape
    size = zoom * [width, height]
    pos = pos - size * align
    if rint:
        pos = np.rint(pos)
    fmt = get_texture_format(image.dtype, channels)

    # Save/restore raster and pixel-store state around the draw.
    gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
    gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
    gl.glRasterPos2f(pos[0], pos[1])
    gl.glPixelZoom(zoom[0], -zoom[1])  # Negative y-zoom flips the image vertically.
    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
    gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
    gl.glPopClientAttrib()
    gl.glPopAttrib()
111
+
112
+ #----------------------------------------------------------------------------
113
+
114
def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
    """Read a block of pixels from the current GL framebuffer.

    Returns an array of shape [height, width, channels] in top-down row order.
    """
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    dtype = np.dtype(dtype)
    fmt = get_texture_format(dtype, channels)
    image = np.empty([height, width, channels], dtype=dtype)

    gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
    gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
    gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
    gl.glPopClientAttrib()
    return np.flipud(image)  # glReadPixels returns rows bottom-up; flip to top-down.
125
+
126
+ #----------------------------------------------------------------------------
127
+
128
class Texture:
    """Wrapper around a 2D OpenGL texture object.

    Can be initialized either from an image array (HW or HWC, uint8/float32)
    or from explicit width/height/channels/dtype. The GL texture is created
    immediately; free it with delete() (also attempted on garbage collection).
    """
    def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
        self.gl_id = None
        self.bilinear = bilinear
        self.mipmap = mipmap

        # Determine size and dtype.
        if image is not None:
            image = prepare_texture_data(image)
            self.height, self.width, self.channels = image.shape
            self.dtype = image.dtype
        else:
            assert width is not None and height is not None
            self.width = width
            self.height = height
            self.channels = channels if channels is not None else 3
            self.dtype = np.dtype(dtype) if dtype is not None else np.uint8

        # Validate size and dtype.
        assert isinstance(self.width, int) and self.width >= 0
        assert isinstance(self.height, int) and self.height >= 0
        assert isinstance(self.channels, int) and self.channels >= 1
        assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)

        # Create texture object.
        self.gl_id = gl.glGenTextures(1)
        with self.bind():
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
            gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
            self.update(image)

    def delete(self):
        """Free the GL texture object. Safe to call multiple times."""
        if self.gl_id is not None:
            gl.glDeleteTextures([self.gl_id])
            self.gl_id = None

    def __del__(self):
        try:
            self.delete()
        except:
            # Best-effort cleanup; the GL context may already be gone at interpreter shutdown.
            pass

    @contextlib.contextmanager
    def bind(self):
        """Temporarily bind this texture, restoring the previous binding afterwards."""
        prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
        gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
        try:
            yield
        finally:
            # Restore the previous binding even if the body raises, so an
            # exception cannot leave global GL state pointing at this texture.
            gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)

    def update(self, image):
        """Upload new pixel data (must match this texture's size/dtype); regenerates mipmaps if enabled."""
        if image is not None:
            image = prepare_texture_data(image)
            assert self.is_compatible(image=image)
        with self.bind():
            fmt = get_texture_format(self.dtype, self.channels)
            gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
            gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
            gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
            if self.mipmap:
                gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
            gl.glPopClientAttrib()

    def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
        """Draw the texture as a (possibly rounded) textured rectangle."""
        zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
        size = zoom * [self.width, self.height]
        with self.bind():
            gl.glPushAttrib(gl.GL_ENABLE_BIT)
            gl.glEnable(gl.GL_TEXTURE_2D)
            draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
            gl.glPopAttrib()

    def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
        """Return True if the given image/parameters match this texture's size and dtype."""
        if image is not None:
            if image.ndim != 3:
                return False
            ih, iw, ic = image.shape
            if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
                return False
        if width is not None and self.width != width:
            return False
        if height is not None and self.height != height:
            return False
        if channels is not None and self.channels != channels:
            return False
        if dtype is not None and self.dtype != dtype:
            return False
        return True
217
+
218
+ #----------------------------------------------------------------------------
219
+
220
class Framebuffer:
    """Wrapper around an OpenGL framebuffer object (FBO).

    Renders either into a given Texture (msaa must be 0) or into an internal
    multisampled color renderbuffer; a depth/stencil renderbuffer is always
    attached. Free with delete() (also attempted on garbage collection).
    """
    def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
        self.texture = texture
        self.gl_id = None
        self.gl_color = None
        self.gl_depth_stencil = None
        self.msaa = msaa

        # Determine size and dtype.
        if texture is not None:
            assert isinstance(self.texture, Texture)
            self.width = texture.width
            self.height = texture.height
            self.channels = texture.channels
            self.dtype = texture.dtype
        else:
            assert width is not None and height is not None
            self.width = width
            self.height = height
            self.channels = channels if channels is not None else 4
            self.dtype = np.dtype(dtype) if dtype is not None else np.float32

        # Validate size and dtype.
        assert isinstance(self.width, int) and self.width >= 0
        assert isinstance(self.height, int) and self.height >= 0
        assert isinstance(self.channels, int) and self.channels >= 1
        assert width is None or width == self.width
        assert height is None or height == self.height
        assert channels is None or channels == self.channels
        assert dtype is None or dtype == self.dtype

        # Create framebuffer object.
        self.gl_id = gl.glGenFramebuffers(1)
        with self.bind():

            # Setup color buffer.
            if self.texture is not None:
                assert self.msaa == 0  # Plain 2D textures cannot be multisampled render targets here.
                gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
            else:
                fmt = get_texture_format(self.dtype, self.channels)
                self.gl_color = gl.glGenRenderbuffers(1)
                gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
                gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
                gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)

            # Setup depth/stencil buffer.
            self.gl_depth_stencil = gl.glGenRenderbuffers(1)
            gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
            gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
            gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)

    def delete(self):
        """Free the FBO and any internal renderbuffers. Safe to call multiple times."""
        if self.gl_id is not None:
            gl.glDeleteFramebuffers([self.gl_id])
            self.gl_id = None
        if self.gl_color is not None:
            gl.glDeleteRenderbuffers(1, [self.gl_color])
            self.gl_color = None
        if self.gl_depth_stencil is not None:
            gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
            self.gl_depth_stencil = None

    def __del__(self):
        try:
            self.delete()
        except:
            # Best-effort cleanup; the GL context may already be gone at interpreter shutdown.
            pass

    @contextlib.contextmanager
    def bind(self):
        """Temporarily bind this FBO (and set the viewport), restoring previous bindings afterwards."""
        prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
        prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
        gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
        if self.width is not None and self.height is not None:
            gl.glViewport(0, 0, self.width, self.height)
        try:
            yield
        finally:
            # Restore previous bindings even if the body raises.
            gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
            gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)

    def blit(self, dst=None):
        """Copy/resolve this framebuffer's color contents into `dst` (or the default framebuffer)."""
        assert dst is None or isinstance(dst, Framebuffer)
        with self.bind():
            # Bug fix: the original referenced `dst.fbo`, but Framebuffer stores its
            # handle in `gl_id` -- passing a non-None dst raised AttributeError.
            gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.gl_id)
            gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
305
+
306
+ #----------------------------------------------------------------------------
307
+
308
def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
    """Draw a 2D vertex array translated by `pos` and scaled by `size`.

    vertices: float array of shape [N, 2]; the same coordinates are also used
    as texture coordinates, so a bound texture maps across the unit shape.
    """
    assert vertices.ndim == 2 and vertices.shape[1] == 2
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
    color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
    alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)

    # Save/restore client array, color, and matrix state around the draw.
    gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
    gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
    gl.glMatrixMode(gl.GL_MODELVIEW)
    gl.glPushMatrix()

    gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
    gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
    gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
    gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
    gl.glTranslate(pos[0], pos[1], 0)
    gl.glScale(size[0], size[1], 1)
    gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)  # Pre-multiplied alpha.
    gl.glDrawArrays(mode, 0, vertices.shape[0])

    gl.glPopMatrix()
    gl.glPopAttrib()
    gl.glPopClientAttrib()
332
+
333
+ #----------------------------------------------------------------------------
334
+
335
def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
    """Draw a filled rectangle, optionally with rounded corners.

    Specify the extent either via `size` or via an opposite corner `pos2`
    (not both). `rounding` is the corner radius in the same units as `size`
    (scalar or [rx, ry]).
    """
    assert pos2 is None or size is None
    pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
    pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
    size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
    size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
    pos = pos - size * align
    if rint:
        pos = np.rint(pos)
    # Convert the radius to a fraction of the rectangle size, capped at 0.5.
    rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
    rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
    if np.min(rounding) == 0:
        rounding *= 0  # If either axis rounds to zero, disable rounding entirely.
    vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
    draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
350
+
351
+ @functools.lru_cache(maxsize=10000)
352
+ def _setup_rect(rx, ry):
353
+ t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
354
+ s = 1 - np.sin(t); c = 1 - np.cos(t)
355
+ x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
356
+ y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
357
+ v = np.stack([x, y], axis=-1).reshape(-1, 2)
358
+ return v.astype('float32')
359
+
360
+ #----------------------------------------------------------------------------
361
+
362
def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
    """Draw a filled circle (or annulus when hole > 0) at `center` with the given `radius`."""
    hole_frac = float(np.broadcast_to(np.asarray(hole, dtype='float32'), []))
    vertex_strip = _setup_circle(hole_frac)
    draw_shape(vertex_strip, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
366
+
367
+ @functools.lru_cache(maxsize=10000)
368
+ def _setup_circle(hole):
369
+ t = np.linspace(0, np.pi * 2, 128)
370
+ s = np.sin(t); c = np.cos(t)
371
+ v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
372
+ return v.astype('float32')
373
+
374
+ #----------------------------------------------------------------------------
stylegan3-fun/gui_utils/glfw_window.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import time
10
+ import glfw
11
+ import OpenGL.GL as gl
12
+ from . import gl_utils
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
class GlfwWindow: # pylint: disable=too-many-public-methods
    """Base class for a GLFW-backed OpenGL window.

    Handles window creation, frame pacing (vsync / FPS limit), frame capture,
    ESC-to-close, and drag-and-drop. Subclasses override draw_frame() and put
    rendering code between begin_frame() and end_frame().
    """

    def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
        self._glfw_window = None
        self._drawing_frame = False         # True between begin_frame() and end_frame().
        self._frame_start_time = None       # perf_counter timestamp at the start of the current frame.
        self._frame_delta = 0               # Seconds between the starts of the last two frames.
        self._fps_limit = None              # Maximum frames per second, or None for unlimited.
        self._vsync = None
        self._skip_frames = 0               # Number of upcoming frames to skip presenting.
        self._deferred_show = deferred_show # Delay showing the window until the first frame is drawn.
        self._close_on_esc = close_on_esc
        self._esc_pressed = False
        self._drag_and_drop_paths = None
        self._capture_next_frame = False
        self._captured_frame = None

        # Create window.
        glfw.init()
        glfw.window_hint(glfw.VISIBLE, False)  # Create hidden; shown later to avoid flashing an empty window.
        self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
        self._attach_glfw_callbacks()
        self.make_context_current()

        # Adjust window.
        self.set_vsync(False)
        self.set_window_size(window_width, window_height)
        if not self._deferred_show:
            glfw.show_window(self._glfw_window)

    def close(self):
        """Destroy the underlying GLFW window. Safe to call multiple times."""
        if self._drawing_frame:
            self.end_frame()
        if self._glfw_window is not None:
            glfw.destroy_window(self._glfw_window)
            self._glfw_window = None
        #glfw.terminate() # Commented out to play it nice with other glfw clients.

    def __del__(self):
        try:
            self.close()
        except:
            pass

    @property
    def window_width(self):
        return self.content_width

    @property
    def window_height(self):
        # Total height including the OS title bar.
        return self.content_height + self.title_bar_height

    @property
    def content_width(self):
        width, _height = glfw.get_window_size(self._glfw_window)
        return width

    @property
    def content_height(self):
        _width, height = glfw.get_window_size(self._glfw_window)
        return height

    @property
    def title_bar_height(self):
        _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
        return top

    @property
    def monitor_width(self):
        _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
        return width

    @property
    def monitor_height(self):
        _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
        return height

    @property
    def frame_delta(self):
        # Seconds between the starts of the last two frames.
        return self._frame_delta

    def set_title(self, title):
        glfw.set_window_title(self._glfw_window, title)

    def set_window_size(self, width, height):
        """Resize, clamping to the monitor work area; maximizes if the size fills the monitor."""
        width = min(width, self.monitor_width)
        height = min(height, self.monitor_height)
        glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
        if width == self.monitor_width and height == self.monitor_height:
            self.maximize()

    def set_content_size(self, width, height):
        # Size of the drawable area, excluding the title bar.
        self.set_window_size(width, height + self.title_bar_height)

    def maximize(self):
        glfw.maximize_window(self._glfw_window)

    def set_position(self, x, y):
        glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)

    def center(self):
        self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)

    def set_vsync(self, vsync):
        vsync = bool(vsync)
        if vsync != self._vsync:
            glfw.swap_interval(1 if vsync else 0)
            self._vsync = vsync

    def set_fps_limit(self, fps_limit):
        self._fps_limit = int(fps_limit)

    def should_close(self):
        return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)

    def skip_frame(self):
        self.skip_frames(1)

    def skip_frames(self, num): # Do not update window for the next N frames.
        self._skip_frames = max(self._skip_frames, int(num))

    def is_skipping_frames(self):
        return self._skip_frames > 0

    def capture_next_frame(self):
        # Request that the next completed frame be stored for pop_captured_frame().
        self._capture_next_frame = True

    def pop_captured_frame(self):
        # Return the most recently captured frame (or None) and clear it.
        frame = self._captured_frame
        self._captured_frame = None
        return frame

    def pop_drag_and_drop_paths(self):
        # Return paths dropped onto the window since the last call (or None) and clear them.
        paths = self._drag_and_drop_paths
        self._drag_and_drop_paths = None
        return paths

    def draw_frame(self): # To be overridden by subclass.
        self.begin_frame()
        # Rendering code goes here.
        self.end_frame()

    def make_context_current(self):
        if self._glfw_window is not None:
            glfw.make_context_current(self._glfw_window)

    def begin_frame(self):
        """Start a new frame: apply the FPS limit, poll events, and set up default GL state."""
        # End previous frame.
        if self._drawing_frame:
            self.end_frame()

        # Apply FPS limit.
        if self._frame_start_time is not None and self._fps_limit is not None:
            delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
            if delay > 0:
                time.sleep(delay)
        cur_time = time.perf_counter()
        if self._frame_start_time is not None:
            self._frame_delta = cur_time - self._frame_start_time
        self._frame_start_time = cur_time

        # Process events.
        glfw.poll_events()

        # Begin frame.
        self._drawing_frame = True
        self.make_context_current()

        # Initialize GL state.
        # Projection maps pixel coordinates with the origin at the top-left
        # (translate to the top-left corner, flip y via the negative scale).
        gl.glViewport(0, 0, self.content_width, self.content_height)
        gl.glMatrixMode(gl.GL_PROJECTION)
        gl.glLoadIdentity()
        gl.glTranslate(-1, 1, 0)
        gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
        gl.glMatrixMode(gl.GL_MODELVIEW)
        gl.glLoadIdentity()
        gl.glEnable(gl.GL_BLEND)
        gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.

        # Clear.
        gl.glClearColor(0, 0, 0, 1)
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

    def end_frame(self):
        """Finish the current frame: optionally capture it, then present (unless skipping)."""
        assert self._drawing_frame
        self._drawing_frame = False

        # Skip frames if requested.
        if self._skip_frames > 0:
            self._skip_frames -= 1
            return

        # Capture frame if requested.
        if self._capture_next_frame:
            self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
            self._capture_next_frame = False

        # Update window.
        if self._deferred_show:
            glfw.show_window(self._glfw_window)
            self._deferred_show = False
        glfw.swap_buffers(self._glfw_window)

    def _attach_glfw_callbacks(self):
        glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
        glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)

    def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
        if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
            self._esc_pressed = True

    def _glfw_drop_callback(self, _window, paths):
        self._drag_and_drop_paths = paths
228
+
229
+ #----------------------------------------------------------------------------
stylegan3-fun/gui_utils/imgui_utils.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import contextlib
10
+ import imgui
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
    """Apply the project's standard ImGui style: spacing, 1px borders, light rounding, and a blended popup background."""
    style = imgui.get_style()
    style.window_padding = [spacing, spacing]
    style.item_spacing = [spacing, spacing]
    style.item_inner_spacing = [spacing, spacing]
    style.columns_min_spacing = spacing
    style.indent_spacing = indent
    style.scrollbar_size = scrollbar
    style.frame_padding = [4, 3]
    for attr in ['window_border_size', 'child_border_size', 'popup_border_size', 'frame_border_size']:
        setattr(style, attr, 1)
    style.window_rounding = 0
    style.child_rounding = 0
    for attr in ['popup_rounding', 'frame_rounding', 'scrollbar_rounding', 'grab_rounding']:
        setattr(style, attr, 3)

    # Load the base color scheme, then derive the popup background as a 70/30
    # blend of the menubar and frame backgrounds, forced fully opaque.
    getattr(imgui, f'style_colors_{color_scheme}')(style)
    menubar_bg = style.colors[imgui.COLOR_MENUBAR_BACKGROUND]
    frame_bg = style.colors[imgui.COLOR_FRAME_BACKGROUND]
    blended = [m * 0.7 + f * 0.3 for m, f in zip(menubar_bg, frame_bg)]
    style.colors[imgui.COLOR_POPUP_BACKGROUND] = blended[:3] + [1]
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
@contextlib.contextmanager
def grayed_out(cond=True):
    """Context manager that renders the enclosed widgets in a disabled (grayed-out) look.

    When `cond` is False, acts as a no-op. The pushed style colors are popped
    in a finally block so an exception in the body cannot leak style state
    (the original skipped the pops on exceptions).
    """
    if not cond:
        yield
        return
    s = imgui.get_style()
    text = s.colors[imgui.COLOR_TEXT_DISABLED]
    grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
    back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
    overrides = [
        (imgui.COLOR_TEXT, text),
        (imgui.COLOR_CHECK_MARK, grab),
        (imgui.COLOR_SLIDER_GRAB, grab),
        (imgui.COLOR_SLIDER_GRAB_ACTIVE, grab),
        (imgui.COLOR_FRAME_BACKGROUND, back),
        (imgui.COLOR_FRAME_BACKGROUND_HOVERED, back),
        (imgui.COLOR_FRAME_BACKGROUND_ACTIVE, back),
        (imgui.COLOR_BUTTON, back),
        (imgui.COLOR_BUTTON_HOVERED, back),
        (imgui.COLOR_BUTTON_ACTIVE, back),
        (imgui.COLOR_HEADER, back),
        (imgui.COLOR_HEADER_HOVERED, back),
        (imgui.COLOR_HEADER_ACTIVE, back),
        (imgui.COLOR_POPUP_BACKGROUND, back),
    ]
    for color_idx, color in overrides:
        imgui.push_style_color(color_idx, *color)
    try:
        yield
    finally:
        imgui.pop_style_color(len(overrides))  # Pop count always matches the push count.
66
+
67
+ #----------------------------------------------------------------------------
68
+
69
@contextlib.contextmanager
def item_width(width=None):
    """Context manager that temporarily sets the ImGui item width (no-op when width is None).

    The pop is done in a finally block so an exception in the body cannot
    leave the pushed width on ImGui's stack (the original skipped the pop
    on exceptions).
    """
    if width is None:
        yield
        return
    imgui.push_item_width(width)
    try:
        yield
    finally:
        imgui.pop_item_width()
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
def scoped_by_object_id(method):
    """Decorator that wraps a method call in an ImGui ID scope keyed on the instance.

    This keeps widget IDs unique when several instances draw identical labels.
    pop_id() runs in a finally block so an exception in the method cannot
    corrupt ImGui's ID stack, and functools.wraps preserves the method's
    name/docstring for debugging.
    """
    import functools

    @functools.wraps(method)
    def decorator(self, *args, **kwargs):
        imgui.push_id(str(id(self)))
        try:
            return method(self, *args, **kwargs)
        finally:
            imgui.pop_id()
    return decorator
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
def button(label, width=0, enabled=True):
    """Draw a button, grayed out when disabled; returns True only for clicks while enabled."""
    with grayed_out(not enabled):
        pressed = imgui.button(label, width=width)
    return pressed and enabled
95
+
96
+ #----------------------------------------------------------------------------
97
+
98
def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
    """Collapsing header that can be hidden, default-open, or disabled.

    Returns (expanded, visible); expanded is forced False when hidden or disabled.
    """
    if not show:
        return False, visible
    if default:
        flags |= imgui.TREE_NODE_DEFAULT_OPEN
    if not enabled:
        flags |= imgui.TREE_NODE_LEAF  # No expand arrow while disabled.
    with grayed_out(not enabled):
        expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
    return (expanded and enabled), visible
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
def popup_button(label, width=0, enabled=True):
    """Button that opens a popup of the same name; returns whether the popup is currently open."""
    if button(label, width, enabled):
        imgui.open_popup(label)
    return imgui.begin_popup(label)
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
def input_text(label, value, buffer_length, flags, width=None, help_text=''):
    """Text input widget with placeholder-text support.

    When `value` is empty, `help_text` is displayed in dimmed color instead.
    Returns (changed, value); `changed` is recomputed by comparing against the
    old value unless INPUT_TEXT_ENTER_RETURNS_TRUE is set, in which case
    imgui's own result is kept.
    """
    old_value = value
    color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
    if value == '':
        color[-1] *= 0.5  # Dim the placeholder text via reduced alpha.
    with item_width(width):
        imgui.push_style_color(imgui.COLOR_TEXT, *color)
        value = value if value != '' else help_text      # Substitute placeholder for display.
        changed, value = imgui.input_text(label, value, buffer_length, flags)
        value = value if value != help_text else ''      # Strip the placeholder back out.
        imgui.pop_style_color(1)
    if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
        changed = (value != old_value)
    return changed, value
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
def drag_previous_control(enabled=True):
    """Report drag state of the most recently drawn control.

    Returns (dragging, dx, dy) where dx/dy is the mouse delta since the last
    call; the delta is reset each time it is read.
    """
    dragging, dx, dy = False, 0, 0
    if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
        if enabled:
            dragging = True
            dx, dy = imgui.get_mouse_drag_delta()
            imgui.reset_mouse_drag_delta()  # Next call reports movement relative to now.
        imgui.end_drag_drop_source()
    return dragging, dx, dy
148
+
149
+ #----------------------------------------------------------------------------
150
+
151
def drag_button(label, width=0, enabled=True):
    """Button that can also be dragged; returns (clicked, dragging, dx, dy)."""
    was_clicked = button(label, width=width, enabled=enabled)
    drag_state = drag_previous_control(enabled=enabled)
    return (was_clicked, *drag_state)
155
+
156
+ #----------------------------------------------------------------------------
157
+
158
def drag_hidden_window(label, x, y, width, height, enabled=True):
    """Invisible fixed window used purely as a drag surface; returns (dragging, dx, dy)."""
    # Make both the window background and its border fully transparent.
    for color_idx in (imgui.COLOR_WINDOW_BACKGROUND, imgui.COLOR_BORDER):
        imgui.push_style_color(color_idx, 0, 0, 0, 0)
    imgui.set_next_window_position(x, y)
    imgui.set_next_window_size(width, height)
    window_flags = imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE
    imgui.begin(label, closable=False, flags=window_flags)
    drag_state = drag_previous_control(enabled=enabled)
    imgui.end()
    imgui.pop_style_color(2)
    return drag_state
168
+
169
+ #----------------------------------------------------------------------------
stylegan3-fun/gui_utils/imgui_window.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import imgui
11
+ import imgui.integrations.glfw
12
+
13
+ from . import glfw_window
14
+ from . import imgui_utils
15
+ from . import text_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
class ImguiWindow(glfw_window.GlfwWindow):
    """GLFW window with an attached Dear ImGui context.

    Layers ImGui frame handling on top of glfw_window.GlfwWindow: context and
    renderer setup, a font atlas pre-rasterized at several pixel sizes, and
    per-frame begin/end plumbing.
    """

    def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
        # font: path/URL to a .ttf file; defaults to the font returned by
        #   text_utils.get_default_font().
        # font_sizes: iterable of pixel sizes to pre-rasterize into the atlas.
        if font is None:
            font = text_utils.get_default_font()
        font_sizes = {int(size) for size in font_sizes}
        super().__init__(title=title, **glfw_kwargs)

        # Init fields.
        self._imgui_context = None
        self._imgui_renderer = None
        self._imgui_fonts = None
        self._cur_font_size = max(font_sizes)

        # Delete leftover imgui.ini to avoid unexpected behavior.
        if os.path.isfile('imgui.ini'):
            os.remove('imgui.ini')

        # Init ImGui. Order matters: the context must exist before the renderer
        # is created, and fonts must be registered before the texture refresh.
        self._imgui_context = imgui.create_context()
        self._imgui_renderer = _GlfwRenderer(self._glfw_window)
        self._attach_glfw_callbacks()
        imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
        imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
        self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
        self._imgui_renderer.refresh_font_texture()

    def close(self):
        """Shut down the ImGui renderer and close the underlying GLFW window."""
        self.make_context_current()
        self._imgui_fonts = None
        if self._imgui_renderer is not None:
            self._imgui_renderer.shutdown()
            self._imgui_renderer = None
        if self._imgui_context is not None:
            #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
            self._imgui_context = None
        super().close()

    def _glfw_key_callback(self, *args):
        # Forward key events to both the base window and the ImGui renderer.
        super()._glfw_key_callback(*args)
        self._imgui_renderer.keyboard_callback(*args)

    @property
    def font_size(self):
        # Font size (pixels) in effect for the current frame.
        return self._cur_font_size

    @property
    def spacing(self):
        # Default widget spacing, proportional to the current font size.
        return round(self._cur_font_size * 0.4)

    def set_font_size(self, target): # Applied on next frame.
        # Snap to the closest pre-rasterized font size.
        self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]

    def begin_frame(self):
        """Start a new frame: glfw bookkeeping, ImGui input processing, and
        font/style setup for this frame's widgets."""
        # Begin glfw frame.
        super().begin_frame()

        # Process imgui events.
        self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
        if self.content_width > 0 and self.content_height > 0:
            self._imgui_renderer.process_inputs()

        # Begin imgui frame.
        imgui.new_frame()
        imgui.push_font(self._imgui_fonts[self._cur_font_size])
        imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)

    def end_frame(self):
        """Finish the frame: render ImGui draw data and hand off to glfw."""
        imgui.pop_font()
        imgui.render()
        imgui.end_frame()
        self._imgui_renderer.render(imgui.get_draw_data())
        super().end_frame()
91
+
92
+ #----------------------------------------------------------------------------
93
+ # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
94
+
95
class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
    """GlfwRenderer subclass that scales mouse-wheel input.

    Works around a mouse wheel bug on Linux (see section comment above this
    class) by multiplying each scroll event by `mouse_wheel_multiplier`,
    which ImguiWindow updates every frame from the current font size.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mouse_wheel_multiplier = 1  # Scale factor applied to scroll events.

    def scroll_callback(self, window, x_offset, y_offset):
        # Only vertical scrolling is forwarded; x_offset is intentionally dropped.
        self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
102
+
103
+ #----------------------------------------------------------------------------
stylegan3-fun/gui_utils/text_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import functools
10
+ from typing import Optional
11
+
12
+ import dnnlib
13
+ import numpy as np
14
+ import PIL.Image
15
+ import PIL.ImageFont
16
+ import scipy.ndimage
17
+
18
+ from . import gl_utils
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
def get_default_font():
    """Download (and cache via dnnlib) the default UI font; return its local path."""
    url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
    # NOTE(review): plain-http fetch — consider https; verify the dnnlib URL
    # cache still resolves before changing, as the cache key derives from the URL.
    return dnnlib.util.open_url(url, return_filename=True)
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
@functools.lru_cache(maxsize=None)
def get_pil_font(font=None, size=32):
    """Load a PIL truetype font, memoized per (font, size).

    `font` is a path to a .ttf file; None selects the default font.
    """
    path = get_default_font() if font is None else font
    return PIL.ImageFont.truetype(font=path, size=size)
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
def get_array(string, *, dropshadow_radius: int=None, **kwargs):
    """Rasterize `string` to a [H, W, 2] uint8 array (mask, alpha).

    When `dropshadow_radius` is given, a drop shadow is rendered with an
    offset of ceil(radius * 2/3) pixels in both x and y.
    """
    if dropshadow_radius is None:
        return _get_array_priv(string, **kwargs)
    shadow_offset = int(np.ceil(dropshadow_radius * 2 / 3))  # Same offset on both axes.
    return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=shadow_offset, offset_y=shadow_offset, **kwargs)
43
+
44
@functools.lru_cache(maxsize=10000)
def _get_array_priv(
    string: str, *,
    size: int = 32,
    max_width: Optional[int]=None,
    max_height: Optional[int]=None,
    min_size=10,
    shrink_coef=0.8,
    dropshadow_radius: int=None,
    offset_x: int=None,
    offset_y: int=None,
    **kwargs
):
    """Render `string`, shrinking the font size until the result fits within
    max_width/max_height (or the size floor `min_size` is reached).

    Delegates to _get_array_impl_dropshadow() when a shadow is requested,
    otherwise to _get_array_impl(). Memoized.
    """
    attempt_size = size
    while True:
        if dropshadow_radius is None:
            rendered = _get_array_impl(string, size=attempt_size, **kwargs)
        else:
            # Separate implementation for dropshadow text rendering.
            rendered = _get_array_impl_dropshadow(string, size=attempt_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
        h, w, _ = rendered.shape
        fits_w = (max_width is None) or (w <= max_width)
        fits_h = (max_height is None) or (h <= max_height)
        if (fits_w and fits_h) or attempt_size <= min_size:
            return rendered
        attempt_size = max(int(attempt_size * shrink_coef), min_size)
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
@functools.lru_cache(maxsize=10000)
def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
    """Rasterize `string` into a uint8 array of shape [height, width, 2].

    Channel 0 is the glyph coverage mask; channel 1 is an alpha channel that
    additionally carries a soft outline halo when `outline` > 0. Memoized.
    """
    pil_font = get_pil_font(font=font, size=size)
    # Rasterize each text line separately as an 8-bit coverage bitmap.
    lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
    lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
    # Pad all lines to a common width, then stack vertically with spacing
    # (no spacing after the last line).
    width = max(line.shape[1] for line in lines)
    lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
    line_spacing = line_pad if line_pad is not None else size // 2
    lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
    mask = np.concatenate(lines, axis=0)
    alpha = mask
    if outline > 0:
        # Grow the canvas so the blurred outline is not clipped at the edges.
        mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
        alpha = mask.astype(np.float32) / 255
        alpha = scipy.ndimage.gaussian_filter(alpha, outline)
        # Sharpen the Gaussian halo into a bounded outline band.
        alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
        alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
        # Glyph interior stays fully opaque regardless of the outline shaping.
        alpha = np.maximum(alpha, mask)
    return np.stack([mask, alpha], axis=-1)
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
@functools.lru_cache(maxsize=10000)
def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
    """Rasterize `string` with a drop shadow into a uint8 array [H, W, 2].

    Channel 0 is the glyph mask; channel 1 is an alpha channel containing
    the Gaussian shadow shifted by (offset_x, offset_y) pixels. Memoized.
    """
    assert (offset_x > 0) and (offset_y > 0)
    pil_font = get_pil_font(font=font, size=size)
    # Rasterize each text line separately as an 8-bit coverage bitmap.
    lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
    lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
    # Pad all lines to a common width, then stack vertically with spacing
    # (no spacing after the last line).
    width = max(line.shape[1] for line in lines)
    lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
    line_spacing = line_pad if line_pad is not None else size // 2
    lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
    mask = np.concatenate(lines, axis=0)
    alpha = mask

    # Grow the canvas so the blurred and offset shadow is not clipped.
    mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
    alpha = mask.astype(np.float32) / 255
    alpha = scipy.ndimage.gaussian_filter(alpha, radius)
    # Shape the blurred halo into the shadow band (fixed coefficients 1.5 / 1.4).
    alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
    alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
    # Shift the shadow down-right by (offset_y, offset_x) pixels.
    alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
    # Glyph interior stays fully opaque on top of the shadow.
    alpha = np.maximum(alpha, mask)
    return np.stack([mask, alpha], axis=-1)
116
+
117
+ #----------------------------------------------------------------------------
118
+
119
@functools.lru_cache(maxsize=10000)
def get_texture(string, bilinear=True, mipmap=True, **kwargs):
    """Rasterize `string` (via get_array) and upload it as a gl_utils.Texture. Memoized."""
    return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
122
+
123
+ #----------------------------------------------------------------------------
stylegan3-fun/legacy.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Converting legacy network pickle into the new format."""
10
+
11
+ import click
12
+ import pickle
13
+ import re
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ from torch_utils import misc
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
def load_network_pkl(f, force_fp16=False):
    """Load a network pickle, converting legacy TensorFlow pickles on the fly.

    f: binary file-like object containing the pickle.
    force_fp16: if True, re-instantiate G/D/G_ema with num_fp16_res=4 and
        conv_clamp=256, copying parameters and buffers over.
    Returns a dict containing at least 'G_ema'; typically also 'G', 'D',
    'training_set_kwargs', and 'augment_pipe'.
    """
    data = _LegacyUnpickler(f).load()
    try:
        # Legacy TensorFlow pickle => convert.
        if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
            tf_G, tf_D, tf_Gs = data
            G = convert_tf_generator(tf_G)
            D = convert_tf_discriminator(tf_D)
            G_ema = convert_tf_generator(tf_Gs)
            data = dict(G=G, D=D, G_ema=G_ema)

        # Add missing fields.
        if 'training_set_kwargs' not in data:
            data['training_set_kwargs'] = None
        if 'augment_pipe' not in data:
            data['augment_pipe'] = None

        # Validate contents.
        if 'G' in data:
            assert isinstance(data['G'], torch.nn.Module)
        if 'D' in data:
            assert isinstance(data['D'], torch.nn.Module)
        assert isinstance(data['G_ema'], torch.nn.Module)
        assert isinstance(data['training_set_kwargs'], (dict, type(None)))
        assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

        # Force FP16.
        if force_fp16:
            for key in ['G', 'D', 'G_ema']:
                old = data[key]
                kwargs = copy.deepcopy(old.init_kwargs)
                # Presumably some configs nest these under 'synthesis_kwargs';
                # others keep them top-level — verify against the network class.
                fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
                fp16_kwargs.num_fp16_res = 4
                fp16_kwargs.conv_clamp = 256
                # Rebuild only if the kwargs actually changed.
                if kwargs != old.init_kwargs:
                    new = type(old)(**kwargs).eval().requires_grad_(False)
                    misc.copy_params_and_buffers(old, new, require_all=True)
                    data[key] = new
    except KeyError:
        # Most likely a StyleGAN-NADA pkl, so pass and return data
        # NOTE(review): this also silences KeyErrors raised during a genuine
        # conversion failure, returning possibly-incomplete data — confirm intended.
        pass
    return data
64
+
65
+ #----------------------------------------------------------------------------
66
+
67
class _TFNetworkStub(dnnlib.EasyDict):
    """Stand-in for dnnlib.tflib.network.Network when unpickling legacy TF pickles.

    Holds the pickled attribute dict (e.g. version, static_kwargs, variables,
    components — as consumed by the convert_tf_* functions) without requiring
    TensorFlow to be installed.
    """
    pass
69
+
70
+ class _LegacyUnpickler(pickle.Unpickler):
71
+ def find_class(self, module, name):
72
+ if module == 'dnnlib.tflib.network' and name == 'Network':
73
+ return _TFNetworkStub
74
+ return super().find_class(module, name)
75
+
76
+ #----------------------------------------------------------------------------
77
+
78
+ def _collect_tf_params(tf_net):
79
+ # pylint: disable=protected-access
80
+ tf_params = dict()
81
+ def recurse(prefix, tf_net):
82
+ for name, value in tf_net.variables:
83
+ tf_params[prefix + name] = value
84
+ for name, comp in tf_net.components.items():
85
+ recurse(prefix + name + '/', comp)
86
+ recurse('', tf_net)
87
+ return tf_params
88
+
89
+ #----------------------------------------------------------------------------
90
+
91
def _populate_module_params(module, *patterns):
    """Copy values into a module's params/buffers using (regex, value_fn) pairs.

    `patterns` alternates regex strings with callables: for each named
    parameter/buffer, the first regex that fully matches wins, and its
    callable — invoked with the regex capture groups — supplies the value.
    A callable of None means "recognized, but leave untouched". Any name
    matching no pattern raises (via the assert below).
    """
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            # Print the offending name and shape before re-raising, to aid
            # debugging shape/pattern mismatches during conversion.
            print(name, list(tensor.shape))
            raise
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow generator into networks_stylegan2.Generator.

    tf_G: _TFNetworkStub carrying `version`, `static_kwargs`, `variables`,
        and `components` from the TF pickle.
    Returns the converted generator in eval mode with gradients disabled.
    Raises ValueError for unsupported pickle versions or unknown kwargs.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        # Record the TF kwarg as known and fetch it, mapping None to `none`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    from training import networks_stylegan2
    network_class = networks_stylegan2.Generator
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        architecture = kwarg('architecture', 'skip'),
        resample_filter = kwarg('resample_kernel', [1,3,3,1]),
        use_noise = kwarg('use_noise', True),
        activation = kwarg('nonlinearity', 'lrelu'),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
    )
    # Check for unknown kwargs. These are recognized but intentionally ignored.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    kwarg('conditioning')
    kwarg('fused_modconv')
    kwarg('randomize_noise') # PR #173 by @cobanov, for StyleGAN2 models/TF 1.15
    kwarg('resolution_h') # For transferring from --network=anime1024 --cfg=stylegan2-ext
    kwarg('resolution_w') # For transferring from --network=anime1024 --cfg=stylegan2-ext
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)

    # Add changes from --cfg=stylegan2-ext; currently can't think of another way to do this
    # BUG FIX: the original condition `if 'resolution_h' and 'resolution_w' in tf_kwargs:`
    # only tested 'resolution_w' — the string literal 'resolution_h' is always truthy.
    if 'resolution_h' in tf_kwargs and 'resolution_w' in tf_kwargs:
        kwargs.channel_max = 1024
        kwargs.w_dim = 1024
        kwargs.z_dim = 1024
        kwargs.mapping_kwargs.layer_features = 1024
        kwargs.mapping_kwargs.num_layers = 4
        kwargs.channel_base *= 2

    # Remap progressive-growing ToRGB_lod* names to resolution-based names.
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis.kwargs.architecture = 'orig' # NOTE(review): `kwargs` has no 'synthesis' entry above; this rarely-taken path looks broken — confirm before relying on it.
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    G = network_class(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
        r'.*\.act_filter', None,
    )
    return G
222
+
223
+ #----------------------------------------------------------------------------
224
+
225
def convert_tf_discriminator(tf_D):
    """Convert a legacy TensorFlow discriminator into networks_stylegan2.Discriminator.

    tf_D: _TFNetworkStub carrying `version`, `static_kwargs`, `variables`,
        and `components` from the TF pickle.
    Returns the converted discriminator in eval mode with gradients disabled.
    Raises ValueError for unsupported pickle versions or unknown kwargs.
    """
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        # Record the TF kwarg as known and fetch it (or the default).
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs. These are recognized but intentionally ignored.
    kwarg('structure')
    kwarg('conditioning')
    kwarg('resolution_h') # For transferring from --network=anime1024 --cfg=stylegan2-ext
    kwarg('resolution_w') # For transferring from --network=anime1024 --cfg=stylegan2-ext
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Add changes from --cfg=stylegan2-ext. (I really can't think of another way tbh)
    # BUG FIX: the original condition `if 'resolution_h' and 'resolution_w' in tf_kwargs:`
    # only tested 'resolution_w' — the string literal 'resolution_h' is always truthy.
    if 'resolution_h' in tf_kwargs and 'resolution_w' in tf_kwargs:
        kwargs.epilogue_kwargs.mbstd_num_channels = 4

    # Collect params, remapping progressive-growing FromRGB_lod* names.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks_stylegan2
    D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # pylint: disable=f-string-without-interpolation
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D
313
+
314
+ #----------------------------------------------------------------------------
315
+
316
@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:  # Accepts local paths as well as URLs.
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')
340
+
341
+ #----------------------------------------------------------------------------
342
+
343
# Command-line entry point; click parses the arguments from sys.argv.
if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter
345
+
346
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
stylegan3-fun/metrics/equivariance.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
10
+ "Alias-Free Generative Adversarial Networks"."""
11
+
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import torch.fft
16
+ from torch_utils.ops import upfirdn2d
17
+ from . import metric_utils
18
+
19
+ #----------------------------------------------------------------------------
20
+ # Utilities.
21
+
22
def sinc(x):
    """Normalized sinc: sin(pi*x) / (pi*x), with sinc(0) defined as 1.

    Element-wise on a tensor; the clamp below guards the division against
    zero, and torch.where substitutes the exact limit value at x == 0.
    """
    scaled = (np.pi * x).abs()
    ratio = torch.sin(scaled) / scaled.clamp(1e-30, float('inf'))
    return torch.where(scaled < 1e-30, torch.ones_like(x), ratio)
26
+
27
def lanczos_window(x, a):
    """Lanczos window of support `a`: sinc(x/a) for |x| < a, zero outside."""
    scaled = x.abs() / a
    return torch.where(scaled < 1, sinc(scaled), torch.zeros_like(scaled))
30
+
31
def rotation_matrix(angle):
    """Build a 3x3 homogeneous 2D rotation matrix for `angle` (radians).

    Accepts a Python number or tensor; the result lives on the same device.
    """
    angle = torch.as_tensor(angle).to(torch.float32)
    c, s = angle.cos(), angle.sin()
    mat = torch.eye(3, device=angle.device)
    mat[0, 0] = c
    mat[0, 1] = s
    mat[1, 0] = -s
    mat[1, 1] = c
    return mat
39
+
40
+ #----------------------------------------------------------------------------
41
+ # Apply integer translation to a batch of 2D images. Corresponds to the
42
+ # operator T_x in Appendix E.1.
43
+
44
def apply_integer_translation(x, tx, ty):
    """Translate a batch of images [N, C, H, W] by a whole-pixel offset.

    tx, ty give the translation as a fraction of width/height; they are
    scaled to pixels and rounded to the nearest integer. Returns a tuple
    (translated, mask) where mask is 1 on pixels that received valid data
    and 0 on the vacated border. Operator T_x of Appendix E.1.
    """
    _N, _C, H, W = x.shape
    shift_x = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    shift_y = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = shift_x.round().to(torch.int64)
    iy = shift_y.round().to(torch.int64)

    out = torch.zeros_like(x)
    mask = torch.zeros_like(x)
    # A shift of a full image dimension (or more) leaves nothing visible.
    if abs(ix) < W and abs(iy) < H:
        src = x[:, :, max(-iy, 0) : H + min(-iy, 0), max(-ix, 0) : W + min(-ix, 0)]
        out[:, :, max(iy, 0) : H + min(iy, 0), max(ix, 0) : W + min(ix, 0)] = src
        mask[:, :, max(iy, 0) : H + min(iy, 0), max(ix, 0) : W + min(ix, 0)] = 1
    return out, mask
58
+
59
+ #----------------------------------------------------------------------------
60
+ # Apply integer translation to a batch of 2D images. Corresponds to the
61
+ # operator T_x in Appendix E.2.
62
+
63
def apply_fractional_translation(x, tx, ty, a=3):
    """Translate a batch of images [N, C, H, W] by a fractional-pixel offset.

    tx, ty give the translation as a fraction of width/height; `a` is the
    support of the windowed-sinc resampling filter. Returns (translated,
    mask) where mask is 1 only on pixels unaffected by boundary effects.
    Operator T_x of Appendix E.2.
    """
    _N, _C, H, W = x.shape
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    # Split the shift into integer (ix, iy) and fractional (fx, fy) parts.
    ix = tx.floor().to(torch.int64)
    iy = ty.floor().to(torch.int64)
    fx = tx - ix
    fy = ty - iy
    b = a - 1  # Filter taps to the left of / above the center.

    z = torch.zeros_like(x)
    # Destination rectangle that receives filtered content.
    zx0 = max(ix - b, 0)
    zy0 = max(iy - b, 0)
    zx1 = min(ix + a, 0) + W
    zy1 = min(iy + a, 0) + H
    if zx0 < zx1 and zy0 < zy1:
        # Separable windowed-sinc filters realizing the fractional shift.
        taps = torch.arange(a * 2, device=x.device) - b
        filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
        filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
        y = x
        # Filters are normalized to unit sum to preserve overall intensity.
        y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
        y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
        y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
        z[:, :, zy0:zy1, zx0:zx1] = y

    m = torch.zeros_like(x)
    # Valid region: pixels whose filter footprint stayed inside the image.
    mx0 = max(ix + a, 0)
    my0 = max(iy + a, 0)
    mx1 = min(ix - b, 0) + W
    my1 = min(iy - b, 0) + H
    if mx0 < mx1 and my0 < my1:
        m[:, :, my0:my1, mx0:mx1] = 1
    return z, m
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Construct an oriented low-pass filter that applies the appropriate
99
+ # bandlimit with respect to the input and output of the given affine 2D
100
+ # image transformation.
101
+
102
def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Construct an oriented low-pass filter that applies the appropriate
    bandlimit with respect to both the input and the output coordinate
    spaces of the given affine 2D image transformation.

    Args:
        mat:        Affine transformation matrix; only mat[:2, :2] is used here.
        a:          Half-width of the Lanczos window.
        amax:       Half-width of the final filter, in input pixels.
        aflt:       Half-width used for the intermediate FFT-based convolutions.
        up:         Upsampling factor the filter is designed for.
        cutoff_in:  Bandlimit cutoff in the input coordinate space.
        cutoff_out: Bandlimit cutoff in the output coordinate space.

    Returns:
        2D FIR filter tensor of shape [amax*2*up - 1, amax*2*up - 1],
        normalized so that each polyphase component sums to 1 / up**2.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    # NOTE(review): torch.meshgrid is called without indexing= and relies on
    # the legacy 'ij' default; newer PyTorch warns here — confirm before
    # adding indexing='ij' (requires torch >= 1.10).
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters, computed in Fourier space.
    fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
    fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = lanczos_window(xi, a) * lanczos_window(yi, a)
    wo = lanczos_window(xo, a) * lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize: recenter, crop from aflt to amax, and normalize each of the
    # up x up polyphase components independently.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
131
+
132
+ #----------------------------------------------------------------------------
133
+ # Apply the given affine transformation to a batch of 2D images.
134
+
135
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Apply the given affine transformation to a batch of 2D images.

    The images are upsampled through a transformation-aware bandlimit filter,
    resampled with a bilinear grid sample, and a validity mask is produced
    alongside the result.

    Args:
        x:   Image batch of shape [N, C, H, W].
        mat: Affine transformation matrix (inverted below to build the grid).
        up:  Upsampling factor used during resampling.
        **filter_kwargs: Forwarded to construct_affine_bandlimit_filter().

    Returns:
        (z, m) where z is the transformed image batch and m is a mask that is
        1 where the output is backed by fully-filtered input pixels.
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2  # filter half-width = padding added by upsample2d below

    # Construct sampling grid, compensating for the upsampling offset and the
    # extra padding so the grid addresses the padded upsampled image correctly.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask: mark the valid interior of the padded image, then warp it
    # with the same grid (nearest-neighbor to keep it binary).
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
164
+
165
+ #----------------------------------------------------------------------------
166
+ # Apply fractional rotation to a batch of 2D images. Corresponds to the
167
+ # operator R_\alpha in Appendix E.3.
168
+
169
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    """Rotate a batch of 2D images by `angle` (radians) using a
    transformation-aware bandlimit filter.

    Returns (z, m): rotated image batch and validity mask.
    """
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    return apply_affine_transformation(x, rotation_matrix(theta), a=a, amax=a * 2, **filter_kwargs)
173
+
174
+ #----------------------------------------------------------------------------
175
+ # Modify the frequency content of a batch of 2D images as if they had undergone
176
+ # fractional rotation -- but without actually rotating them. Corresponds to
177
+ # the operator R^*_\alpha in Appendix E.3.
178
+
179
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    """Bandlimit a batch of images as if they had been rotated by `angle`,
    without actually rotating them.

    The oriented low-pass filter for a rotation by -angle is applied so the
    frequency content matches that of a truly rotated image.

    Returns (y, m): filtered image batch and validity mask.
    """
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    rot = rotation_matrix(-theta)
    fir = construct_affine_bandlimit_filter(rot, a=a, amax=a * 2, up=1, **filter_kwargs)
    filtered = upfirdn2d.filter2d(x=x, f=fir)
    mask = torch.zeros_like(filtered)
    border = fir.shape[0] // 2
    mask[:, :, border:-border, border:-border] = 1
    return filtered, mask
188
+
189
+ #----------------------------------------------------------------------------
190
+ # Compute the selected equivariance metrics for the given generator.
191
+
192
def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
    """Compute the selected equivariance metrics for the given generator.

    Each metric is reported as a PSNR (dB) between a transformed reference
    image and the image the generator produces when the same transformation
    is baked into its input layer.

    Args:
        opts:             metric_utils.MetricOptions.
        num_samples:      Total number of latent samples to evaluate.
        batch_size:       Samples per process per iteration.
        translate_max:    Maximum translation as a fraction of image size.
        rotate_max:       Maximum rotation as a fraction of pi radians.
        compute_eqt_int:  Compute integer-translation equivariance (EQ-T).
        compute_eqt_frac: Compute fractional-translation equivariance.
        compute_eqr:      Compute rotation equivariance (EQ-R).

    Returns:
        A single PSNR, or a tuple of PSNRs when several metrics are selected
        (ordered: integer translation, fractional translation, rotation).

    Raises:
        ValueError: if the generator has no user-controllable input transform.
    """
    assert compute_eqt_int or compute_eqt_frac or compute_eqr

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    I = torch.eye(3, device=opts.device)
    # The generator must expose G.synthesis.input.transform for us to drive.
    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
    if M is None:
        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    sums = None
    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        s = []  # alternating [masked squared error, mask] entries for this batch

        # Randomize noise buffers, if any.
        for name, buf in G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Run mapping network.
        z = torch.randn([batch_size, G.z_dim], device=opts.device)
        c = next(c_iter)
        ws = G.mapping(z=z, c=c)

        # Generate reference image with the identity input transform.
        M[:] = I
        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)

        # Integer translation (EQ-T).
        if compute_eqt_int:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            t = (t * G.img_resolution).round() / G.img_resolution  # snap to whole pixels
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_integer_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Fractional translation (EQ-T_frac).
        if compute_eqt_frac:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_fractional_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Rotation (EQ-R).
        if compute_eqr:
            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
            M[:] = rotation_matrix(-angle)
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, ref_mask = apply_fractional_rotation(orig, angle)
            # Pseudo-rotate the generated image so both sides share the same
            # bandlimited frequency content before comparison.
            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
            mask = ref_mask * pseudo_mask
            s += [(ref - pseudo).square() * mask, mask]

        # Accumulate results in float64 to limit summation error.
        s = torch.stack([x.to(torch.float64).sum() for x in s])
        sums = sums + s if sums is not None else s
    progress.update(num_samples)

    # Compute PSNRs; the log10(2)*20 term corresponds to the [-1, 1] pixel range.
    if opts.num_gpus > 1:
        torch.distributed.all_reduce(sums)
    sums = sums.cpu()
    mses = sums[0::2] / sums[1::2]
    psnrs = np.log10(2) * 20 - mses.log10() * 10
    psnrs = tuple(psnrs.numpy())
    return psnrs[0] if len(psnrs) == 1 else psnrs
266
+
267
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
+ import numpy as np
15
+ import scipy.linalg
16
+ from . import metric_utils
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
def compute_fid(opts, max_real, num_gen):
    """Compute the Frechet Inception Distance between real and generated images.

    Args:
        opts:     metric_utils.MetricOptions.
        max_real: Cap on the number of real images (None = use all).
        num_gen:  Number of generated images to sample.

    Returns:
        FID as a float on rank 0; NaN on all other ranks.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    real_stats = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real)
    mu_real, sigma_real = real_stats.get_mean_cov()

    gen_stats = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen)
    mu_gen, sigma_gen = gen_stats.get_mean_cov()

    # Only rank 0 computes and reports the final value.
    if opts.rank != 0:
        return float('nan')

    mean_term = np.square(mu_gen - mu_real).sum()
    cov_sqrt, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    fid = np.real(mean_term + np.trace(sigma_gen + sigma_real - cov_sqrt * 2))
    return float(fid)
40
+
41
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/inception_score.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
def compute_is(opts, num_gen, num_splits):
    """Compute the Inception Score over generated images.

    Args:
        opts:       metric_utils.MetricOptions.
        num_gen:    Number of generated images to sample.
        num_splits: Number of splits to average over.

    Returns:
        (mean, std) of the per-split scores on rank 0; (NaN, NaN) elsewhere.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.

    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan'), float('nan')

    scores = []
    for split_idx in range(num_splits):
        lo = split_idx * num_gen // num_splits
        hi = (split_idx + 1) * num_gen // num_splits
        probs = gen_probs[lo:hi]
        # KL divergence between each sample's class distribution and the
        # marginal distribution of this split.
        kl = probs * (np.log(probs) - np.log(np.mean(probs, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return float(np.mean(scores)), float(np.std(scores))
37
+
38
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
+ GANs". Matches the original implementation by Binkowski et al. at
11
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    """Compute the Kernel Inception Distance between real and generated images.

    Uses an unbiased cubic-polynomial-kernel MMD estimator averaged over
    `num_subsets` random subsets of at most `max_subset_size` features each.

    Returns:
        KID as a float on rank 0; NaN on all other ranks.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan')

    n = real_features.shape[1]  # feature dimensionality
    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
    total = 0
    for _subset_idx in range(num_subsets):
        # Draw subsets without replacement from each feature set.
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        # Cubic polynomial kernel k(u, v) = (u.v / n + 1)^3.
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        total += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    return float(total / num_subsets / m)
45
+
46
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/metric_main.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Main API for computing and reporting quality metrics."""
10
+
11
+ import os
12
+ import time
13
+ import json
14
+ import torch
15
+ import dnnlib
16
+
17
+ from . import metric_utils
18
+ from . import frechet_inception_distance
19
+ from . import kernel_inception_distance
20
+ from . import precision_recall
21
+ from . import perceptual_path_length
22
+ from . import inception_score
23
+ from . import equivariance
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
# Global registry mapping metric name -> implementation function.
_metric_dict = dict()

def register_metric(fn):
    """Decorator that registers *fn* as a metric under its function name."""
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn

def is_valid_metric(metric):
    """Return True if `metric` names a registered metric."""
    return metric in _metric_dict

def list_valid_metrics():
    """Return the names of all registered metrics."""
    return list(_metric_dict)
39
+
40
+ #----------------------------------------------------------------------------
41
+
42
def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    """Run the named metric and return its results decorated with metadata.

    Returns an EasyDict with fields: results, metric, total_time,
    total_time_str, num_gpus.
    """
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate, timing the whole run.
    start_time = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - start_time

    # Broadcast rank 0's values so every rank returns identical results.
    for name, val in list(results.items()):
        if opts.num_gpus > 1:
            tensor_val = torch.as_tensor(val, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=tensor_val, src=0)
            val = float(tensor_val.cpu())
        results[name] = val

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results = dnnlib.EasyDict(results),
        metric = metric,
        total_time = total_time,
        total_time_str = dnnlib.util.format_time(total_time),
        num_gpus = opts.num_gpus,
    )
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as a single JSON line and append it to the
    run directory's per-metric .jsonl log (when run_dir exists)."""
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    # Store the snapshot path relative to the run directory when possible.
    if run_dir is not None and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    record = dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)
    if run_dir is not None and os.path.isdir(run_dir):
        log_path = os.path.join(run_dir, f'metric-{metric}.jsonl')
        with open(log_path, 'at') as f:
            f.write(jsonl_line + '\n')
81
+
82
+ #----------------------------------------------------------------------------
83
+ # Recommended metrics.
84
+
85
@register_metric
def fid50k_full(opts):
    # FID, 50k generated images vs. the full dataset (no size cap, no x-flip).
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return dict(fid50k_full=fid)

@register_metric
def kid50k_full(opts):
    # KID, 50k generated images vs. up to 1M real images.
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k_full=kid)

@register_metric
def pr50k3_full(opts):
    # Precision/recall with 3-nearest-neighbor manifold estimation.
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)

@register_metric
def ppl2_wend(opts):
    # Perceptual path length in W space, endpoint sampling, no center crop.
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return dict(ppl2_wend=ppl)

@register_metric
def eqt50k_int(opts):
    # Integer-translation equivariance PSNR; synthesis forced to fp32.
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
    return dict(eqt50k_int=psnr)

@register_metric
def eqt50k_frac(opts):
    # Fractional-translation equivariance PSNR; synthesis forced to fp32.
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
    return dict(eqt50k_frac=psnr)

@register_metric
def eqr50k(opts):
    # Rotation equivariance PSNR; synthesis forced to fp32.
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
    return dict(eqr50k=psnr)
125
+
126
+ #----------------------------------------------------------------------------
127
+ # Legacy metrics.
128
+
129
@register_metric
def fid50k(opts):
    # Legacy: FID against 50k real images.
    opts.dataset_kwargs.update(max_size=None)
    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return dict(fid50k=fid)

@register_metric
def kid50k(opts):
    # Legacy: KID against 50k real images.
    opts.dataset_kwargs.update(max_size=None)
    kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k=kid)

@register_metric
def pr50k3(opts):
    # Legacy: precision/recall with 3 nearest neighbors, 50k real images.
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_precision=precision, pr50k3_recall=recall)

@register_metric
def is50k(opts):
    # Legacy: Inception Score over 50k generated images, averaged over 10 splits.
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return dict(is50k_mean=mean, is50k_std=std)
152
+
153
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/metric_utils.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utilities used internally by the quality metrics."""
10
+
11
+ import os
12
+ import time
13
+ import hashlib
14
+ import pickle
15
+ import copy
16
+ import uuid
17
+ import numpy as np
18
+ import torch
19
+ import dnnlib
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
class MetricOptions:
    """Shared configuration object passed to every metric implementation.

    Attributes:
        G:              Generator network (or None).
        G_kwargs:       Keyword arguments for generator synthesis calls.
        dataset_kwargs: Keyword arguments for constructing the real dataset.
        num_gpus:       Number of distributed processes.
        rank:           Rank of the calling process.
        device:         Torch device; defaults to cuda:<rank>.
        progress:       ProgressMonitor; only rank 0 forwards real progress.
        cache:          Whether to cache dataset feature stats on disk.
    """

    def __init__(self, G=None, G_kwargs=None, dataset_kwargs=None, num_gpus=1, rank=0, device=None, progress=None, cache=True):
        assert 0 <= rank < num_gpus
        self.G = G
        # Use None defaults instead of the original mutable `{}` defaults to
        # avoid the shared-mutable-default pitfall; behavior for callers that
        # pass a dict (or nothing) is unchanged.
        self.G_kwargs = dnnlib.EasyDict(G_kwargs if G_kwargs is not None else {})
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs if dataset_kwargs is not None else {})
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device if device is not None else torch.device('cuda', rank)
        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache = cache
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ _feature_detector_cache = dict()
38
+
39
def get_feature_detector_name(url):
    """Derive a short human-readable detector name from its URL
    (the final path component without its file extension)."""
    basename = url.split('/')[-1]
    return os.path.splitext(basename)[0]
41
+
42
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Load (and cache in-process) a feature-detector network from a URL.

    In multi-GPU runs, rank 0 downloads first while the other ranks wait at a
    barrier, so the network file is fetched only once.

    Args:
        url:      Pickle URL of the detector network.
        device:   Device to place the loaded network on.
        num_gpus: Number of distributed processes.
        rank:     Rank of the calling process.
        verbose:  Show download progress (leader rank only).
    """
    assert 0 <= rank < num_gpus
    key = (url, device)  # cache one copy per (url, device) pair
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        # NOTE(review): pickle.load on downloaded data is unsafe for untrusted
        # URLs — acceptable only because callers pass known NVIDIA asset URLs.
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = pickle.load(f).to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
def iterate_random_labels(opts, batch_size):
    """Infinite generator of random conditioning-label batches.

    For unconditional generators (c_dim == 0) the same empty label tensor is
    yielded forever; otherwise labels are drawn at random from the dataset.
    """
    if opts.G.c_dim == 0:
        empty_labels = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
        while True:
            yield empty_labels
    else:
        dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
        while True:
            picks = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
            yield torch.from_numpy(np.stack(picks)).pin_memory().to(opts.device)
68
+
69
+ #----------------------------------------------------------------------------
70
+
71
class FeatureStats:
    """Accumulates feature vectors and the statistics the metrics need.

    Depending on the capture flags, stores every raw feature vector
    (capture_all) and/or running sums for mean/covariance (capture_mean_cov).
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all            # keep raw features for get_all()
        self.capture_mean_cov = capture_mean_cov  # keep running sums for get_mean_cov()
        self.max_items = max_items                # stop accepting features beyond this count
        self.num_items = 0
        self.num_features = None                  # set lazily on first append()
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        # Lazily initialize accumulators; the feature width must stay constant.
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        # True once max_items features are collected (always False if unbounded).
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x):
        """Append a [num_items, num_features] array-like of feature vectors."""
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        # Truncate the batch so the total never exceeds max_items.
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            # Accumulate sum and outer-product sum in float64 for stability.
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a torch feature batch, gathering across GPUs when needed."""
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            # Every rank broadcasts its batch so all ranks see identical data.
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        # All captured features as one [num_items, num_features] array.
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        # Torch view of get_all().
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        # Mean vector and covariance matrix of the accumulated features.
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        # Persist the full internal state as a pickle.
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        # Restore a FeatureStats object previously written by save().
        with open(pkl_file, 'rb') as f:
            s = dnnlib.EasyDict(pickle.load(f))
        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
        obj.__dict__.update(s)
        return obj
148
+
149
+ #----------------------------------------------------------------------------
150
+
151
class ProgressMonitor:
    """Tracks and reports progress of a (possibly nested) metric computation.

    Optionally forwards progress to an external callback (progress_fn),
    mapped onto the [pfn_lo, pfn_hi] sub-range of pfn_total.
    """

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag
        self.num_items = num_items
        self.verbose = verbose
        self.flush_interval = flush_interval  # min items between console reports
        self.progress_fn = progress_fn
        self.pfn_lo = pfn_lo
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        self.start_time = time.time()
        self.batch_time = self.start_time
        self.batch_items = 0
        if self.progress_fn is not None:
            # Report the starting position immediately.
            self.progress_fn(self.pfn_lo, self.pfn_total)

    def update(self, cur_items):
        """Record progress at cur_items; output is throttled by flush_interval."""
        assert (self.num_items is None) or (cur_items <= self.num_items)
        # Skip unless we crossed the flush interval or reached the end.
        if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
            return
        cur_time = time.time()
        total_time = cur_time - self.start_time
        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
        if (self.verbose) and (self.tag is not None):
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
        self.batch_time = cur_time
        self.batch_items = cur_items

        # Forward fractional progress, rescaled into this monitor's sub-range.
        if (self.progress_fn is not None) and (self.num_items is not None):
            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the [rel_lo, rel_hi] slice of this one's range."""
        return ProgressMonitor(
            tag = tag,
            num_items = num_items,
            flush_interval = flush_interval,
            verbose = self.verbose,
            progress_fn = self.progress_fn,
            pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
            pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
            pfn_total = self.pfn_total,
        )
193
+
194
+ #----------------------------------------------------------------------------
195
+
196
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
    """Compute (and cache on disk) feature statistics over the real dataset.

    Args:
        opts:               MetricOptions describing dataset, devices, etc.
        detector_url:       URL of the feature-detector pickle.
        detector_kwargs:    Keyword arguments passed to the detector.
        rel_lo, rel_hi:     Progress sub-range for reporting.
        batch_size:         Images per detector forward pass.
        data_loader_kwargs: Optional overrides for the torch DataLoader.
        max_items:          Cap on the number of dataset images used.
        **stats_kwargs:     Forwarded to FeatureStats (capture flags).

    Returns:
        A populated FeatureStats object.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        # Cache file name is keyed on every argument that affects the stats.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree, so rank 0
        # decides and broadcasts the flag).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return FeatureStats.load(cache_file)

    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop: each rank processes an interleaved subset of the dataset.
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        if images.shape[1] == 1:
            # Replicate grayscale images to 3 channels.
            images = images.repeat([1, 3, 1, 1])
        features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

    # Save to cache via a temp file + rename so readers never see partial data.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file) # atomic
    return stats
245
+
246
+ #----------------------------------------------------------------------------
247
+
248
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
    """Compute feature statistics over images sampled from the generator.

    Args:
        opts:            MetricOptions (must contain G and G_kwargs).
        detector_url:    URL of the feature-detector pickle.
        detector_kwargs: Keyword arguments passed to the detector.
        rel_lo, rel_hi:  Progress sub-range for reporting.
        batch_size:      Images per detector forward pass.
        batch_gen:       Images per generator forward pass (must divide batch_size).
        **stats_kwargs:  Forwarded to FeatureStats; must include max_items.

    Returns:
        A populated FeatureStats object.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop: generate images in small sub-batches, then run the detector
    # on the combined batch.
    while not stats.is_full():
        images = []
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            img = G(z=z, c=next(c_iter), **opts.G_kwargs)
            # Map generator output to uint8 in [0, 255].
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            images.append(img)
        images = torch.cat(images)
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])  # grayscale -> 3 channels
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
278
+
279
+ #----------------------------------------------------------------------------
stylegan3-fun/metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
+ Architecture for Generative Adversarial Networks". Matches the original
11
+ implementation by Karras et al. at
12
+ https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
+
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ from . import metric_utils
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ # Spherical interpolation of a batch of vectors.
22
def slerp(a, b, t):
    """Spherical interpolation between batches of vectors.

    Both `a` and `b` are normalized to unit length first; the result is a
    unit vector lying on the great circle between them at fraction `t`.
    """
    a_unit = a / a.norm(dim=-1, keepdim=True)
    b_unit = b / b.norm(dim=-1, keepdim=True)
    cos_theta = (a_unit * b_unit).sum(dim=-1, keepdim=True)
    theta_t = t * torch.acos(cos_theta)
    # Component of b orthogonal to a, renormalized to span the arc.
    ortho = b_unit - cos_theta * a_unit
    ortho = ortho / ortho.norm(dim=-1, keepdim=True)
    out = a_unit * torch.cos(theta_t) + ortho * torch.sin(theta_t)
    # Renormalize to guard against accumulated floating-point drift.
    return out / out.norm(dim=-1, keepdim=True)
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
class PPLSampler(torch.nn.Module):
    """Samples one batch of differential LPIPS distances for PPL.

    Each forward pass draws random latents, interpolates a pair of points
    separated by `epsilon` (in Z or W space), renders both, and returns the
    squared LPIPS distance scaled by 1/epsilon^2.
    """
    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__()
        self.G = copy.deepcopy(G)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop
        self.vgg16 = copy.deepcopy(vgg16)

    def forward(self, c):
        # Random interpolation fractions: uniform in [0,1) for 'full',
        # forced to 0 for 'end' (endpoint sampling).
        scale = 1 if self.sampling == 'full' else 0
        t = torch.rand([c.shape[0]], device=c.device) * scale
        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)

        # Interpolate in W or Z.
        if self.space == 'w':
            w0, w1 = self.G.mapping(z=torch.cat([z0, z1]), c=torch.cat([c, c])).chunk(2)
            frac = t.unsqueeze(1).unsqueeze(2)
            wt0 = w0.lerp(w1, frac)
            wt1 = w0.lerp(w1, frac + self.epsilon)
        else:  # space == 'z'
            frac = t.unsqueeze(1)
            zt0 = slerp(z0, z1, frac)
            zt1 = slerp(z0, z1, frac + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0, zt1]), c=torch.cat([c, c])).chunk(2)

        # Re-randomize the generator's noise buffers for every batch.
        for name, buf in self.G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Render both interpolation endpoints in one synthesis call.
        img = self.G.synthesis(ws=torch.cat([wt0, wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

        # Optional center crop (middle 4/8 x 4/8 region, offset as in StyleGAN).
        if self.crop:
            assert img.shape[2] == img.shape[3]
            cell = img.shape[2] // 8
            img = img[:, :, cell * 3 : cell * 7, cell * 2 : cell * 6]

        # Box-filter downsample to 256x256 when the generator is larger.
        factor = self.G.img_resolution // 256
        if factor > 1:
            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])

        # Scale dynamic range from [-1,1] to [0,255]; replicate grayscale.
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])

        # Differential LPIPS distance, normalized by epsilon^2.
        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        return (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91
+
92
+ #----------------------------------------------------------------------------
93
+
94
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
    """Compute Perceptual Path Length (PPL) for the generator in `opts`.

    Args:
        opts:        metric_utils options (generator, device, rank, num_gpus, progress).
        num_samples: total number of distance samples to draw.
        epsilon:     perturbation step for the differential LPIPS distance.
        space:       'z' or 'w' — latent space to interpolate in.
        sampling:    'full' (t ~ U[0,1)) or 'end' (t = 0).
        crop:        whether to center-crop images before LPIPS.
        batch_size:  samples per forward pass per GPU.

    Returns:
        Mean distance over the central 1st–99th percentile band as a float
        on rank 0; NaN on all other ranks.
    """
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    # Setup sampler and labels.
    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop: every rank samples, then all ranks gather every
    # rank's results via broadcast so the concatenation order is identical.
    dist = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        x = sampler(next(c_iter))
        for src in range(opts.num_gpus):
            y = x.clone()
            if opts.num_gpus > 1:
                torch.distributed.broadcast(y, src=src)
            dist.append(y)
    progress.update(num_samples)

    # Compute PPL on rank 0 only.
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(dist)[:num_samples].cpu().numpy()
    # NumPy 1.22 renamed np.percentile's 'interpolation' kwarg to 'method'
    # and removed the old name in NumPy 2.0; support both.
    try:
        lo = np.percentile(dist, 1, method='lower')
        hi = np.percentile(dist, 99, method='higher')
    except TypeError:  # NumPy < 1.22
        lo = np.percentile(dist, 1, interpolation='lower')
        hi = np.percentile(dist, 99, interpolation='higher')
    # Mean over the central band, discarding the extreme 1% tails.
    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
    return float(ppl)
124
+
125
+ #----------------------------------------------------------------------------