Spaces:
Runtime error
Runtime error
rossellison
commited on
Commit
•
8a860df
1
Parent(s):
181927b
Upload 159 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +87 -0
- requirements.txt +17 -0
- stylegan3-fun/.github/FUNDING.yml +3 -0
- stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md +35 -0
- stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
- stylegan3-fun/.gitignore +173 -0
- stylegan3-fun/__pycache__/legacy.cpython-311.pyc +0 -0
- stylegan3-fun/__pycache__/legacy.cpython-38.pyc +0 -0
- stylegan3-fun/__pycache__/legacy.cpython-39.pyc +0 -0
- stylegan3-fun/dnnlib/__init__.py +9 -0
- stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc +0 -0
- stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
- stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
- stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc +0 -0
- stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc +0 -0
- stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc +0 -0
- stylegan3-fun/dnnlib/util.py +491 -0
- stylegan3-fun/gui_utils/__init__.py +9 -0
- stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc +0 -0
- stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc +0 -0
- stylegan3-fun/gui_utils/gl_utils.py +374 -0
- stylegan3-fun/gui_utils/glfw_window.py +229 -0
- stylegan3-fun/gui_utils/imgui_utils.py +169 -0
- stylegan3-fun/gui_utils/imgui_window.py +103 -0
- stylegan3-fun/gui_utils/text_utils.py +123 -0
- stylegan3-fun/legacy.py +346 -0
- stylegan3-fun/metrics/__init__.py +9 -0
- stylegan3-fun/metrics/equivariance.py +267 -0
- stylegan3-fun/metrics/frechet_inception_distance.py +41 -0
- stylegan3-fun/metrics/inception_score.py +38 -0
- stylegan3-fun/metrics/kernel_inception_distance.py +46 -0
- stylegan3-fun/metrics/metric_main.py +153 -0
- stylegan3-fun/metrics/metric_utils.py +279 -0
- stylegan3-fun/metrics/perceptual_path_length.py +125 -0
app.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import random
|
6 |
+
|
7 |
+
sys.path.append('stylegan3-fun') # change this to the path where dnnlib is located
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import PIL.Image
|
11 |
+
import torch
|
12 |
+
import streamlit as st
|
13 |
+
import dnnlib
|
14 |
+
import legacy
|
15 |
+
|
16 |
+
|
17 |
+
def parse_range(s: Union[str, List]) -> List[int]:
|
18 |
+
'''Parse a comma separated list of numbers or ranges and return a list of ints.
|
19 |
+
|
20 |
+
Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
|
21 |
+
'''
|
22 |
+
if isinstance(s, list): return s
|
23 |
+
ranges = []
|
24 |
+
range_re = re.compile(r'^(\d+)-(\d+)$')
|
25 |
+
for p in s.split(','):
|
26 |
+
m = range_re.match(p)
|
27 |
+
if m:
|
28 |
+
ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
|
29 |
+
else:
|
30 |
+
ranges.append(int(p))
|
31 |
+
return ranges
|
32 |
+
|
33 |
+
def make_transform(translate: Tuple[float,float], angle: float):
|
34 |
+
m = np.eye(3)
|
35 |
+
s = np.sin(angle/360.0*np.pi*2)
|
36 |
+
c = np.cos(angle/360.0*np.pi*2)
|
37 |
+
m[0][0] = c
|
38 |
+
m[0][1] = s
|
39 |
+
m[0][2] = translate[0]
|
40 |
+
m[1][0] = -s
|
41 |
+
m[1][1] = c
|
42 |
+
m[1][2] = translate[1]
|
43 |
+
return m
|
44 |
+
|
45 |
+
def generate_image(network_pkl: str, seed: int, truncation_psi: float, noise_mode: str, translate: Tuple[float,float], rotate: float, class_idx: Optional[int]):
|
46 |
+
print('Loading networks from "%s"...' % network_pkl)
|
47 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
48 |
+
with open(network_pkl, 'rb') as f:
|
49 |
+
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
|
50 |
+
|
51 |
+
# Labels.
|
52 |
+
label = torch.zeros([1, G.c_dim], device=device)
|
53 |
+
if G.c_dim != 0:
|
54 |
+
if class_idx is None:
|
55 |
+
raise Exception('Must specify class label when using a conditional network')
|
56 |
+
label[:, class_idx] = 1
|
57 |
+
|
58 |
+
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
|
59 |
+
|
60 |
+
if hasattr(G.synthesis, 'input'):
|
61 |
+
m = make_transform(translate, rotate)
|
62 |
+
m = np.linalg.inv(m)
|
63 |
+
G.synthesis.input.transform.copy_(torch.from_numpy(m))
|
64 |
+
|
65 |
+
img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
|
66 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
67 |
+
img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
|
68 |
+
return img
|
69 |
+
|
70 |
+
def main():
|
71 |
+
st.title('Kpop Face Generator')
|
72 |
+
|
73 |
+
st.write('Press the button below to generate a new image:')
|
74 |
+
if st.button('Generate'):
|
75 |
+
network_pkl = 'kpopGG.pkl'
|
76 |
+
seed = random.randint(0, 99999)
|
77 |
+
truncation_psi = 0.45
|
78 |
+
noise_mode = 'const'
|
79 |
+
translate = (0.0, 0.0)
|
80 |
+
rotate = 0.0
|
81 |
+
class_idx = None
|
82 |
+
|
83 |
+
image = generate_image(network_pkl, seed, truncation_psi, noise_mode, translate, rotate, class_idx)
|
84 |
+
st.image(image)
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
click==8.1.3
|
2 |
+
ffmpeg_python==0.2.0
|
3 |
+
glfw==2.2.0
|
4 |
+
imgui==1.3.0
|
5 |
+
matplotlib==3.4.2
|
6 |
+
numpy==1.21.6
|
7 |
+
Pillow==9.5.0
|
8 |
+
Pillow==10.0.0
|
9 |
+
PyOpenGL==3.1.5
|
10 |
+
PyOpenGL==3.1.7
|
11 |
+
Requests==2.31.0
|
12 |
+
scipy==1.7.3
|
13 |
+
setuptools==58.1.0
|
14 |
+
setuptools==60.2.0
|
15 |
+
streamlit==1.25.0
|
16 |
+
torch==2.0.1
|
17 |
+
training==0.0.1
|
stylegan3-fun/.github/FUNDING.yml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# These are supported funding model platforms
|
2 |
+
|
3 |
+
github: PDillis # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
stylegan3-fun/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Bug report
|
3 |
+
about: Create a report to help us improve
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Describe the bug**
|
11 |
+
A clear and concise description of what the bug is.
|
12 |
+
|
13 |
+
**To Reproduce**
|
14 |
+
Steps to reproduce the behavior:
|
15 |
+
1. In '...' directory, run command '...'
|
16 |
+
2. See error (copy&paste full log, including exceptions and **stacktraces**).
|
17 |
+
|
18 |
+
Please copy&paste text instead of screenshots for better searchability.
|
19 |
+
|
20 |
+
**Expected behavior**
|
21 |
+
A clear and concise description of what you expected to happen.
|
22 |
+
|
23 |
+
**Screenshots**
|
24 |
+
If applicable, add screenshots to help explain your problem.
|
25 |
+
|
26 |
+
**Desktop (please complete the following information):**
|
27 |
+
- OS: [e.g. Linux Ubuntu 20.04, Windows 10]
|
28 |
+
- PyTorch version (e.g., pytorch 1.9.0)
|
29 |
+
- CUDA toolkit version (e.g., CUDA 11.4)
|
30 |
+
- NVIDIA driver version
|
31 |
+
- GPU [e.g., Titan V, RTX 3090]
|
32 |
+
- Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3)
|
33 |
+
|
34 |
+
**Additional context**
|
35 |
+
Add any other context about the problem here.
|
stylegan3-fun/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Feature request
|
3 |
+
about: Suggest an idea for this project
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Is your feature request related to a problem? Please describe.**
|
11 |
+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
12 |
+
|
13 |
+
**Describe the solution you'd like**
|
14 |
+
A clear and concise description of what you want to happen.
|
15 |
+
|
16 |
+
**Describe alternatives you've considered**
|
17 |
+
A clear and concise description of any alternative solutions or features you've considered.
|
18 |
+
|
19 |
+
**Additional context**
|
20 |
+
Add any other context or screenshots about the feature request here.
|
stylegan3-fun/.gitignore
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
*/**/__pycache__/
|
3 |
+
__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
6 |
+
|
7 |
+
# C extensions
|
8 |
+
*.so
|
9 |
+
|
10 |
+
# Distribution / packaging
|
11 |
+
.Python
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
cover/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
.pybuilder/
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# IPython
|
83 |
+
profile_default/
|
84 |
+
ipython_config.py
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
# For a library or package, you might want to ignore these files since the code is
|
88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
89 |
+
# .python-version
|
90 |
+
|
91 |
+
# pipenv
|
92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
95 |
+
# install all needed dependencies.
|
96 |
+
#Pipfile.lock
|
97 |
+
|
98 |
+
# poetry
|
99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
101 |
+
# commonly ignored for libraries.
|
102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
103 |
+
#poetry.lock
|
104 |
+
|
105 |
+
# pdm
|
106 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
107 |
+
#pdm.lock
|
108 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
109 |
+
# in version control.
|
110 |
+
# https://pdm.fming.dev/#use-with-ide
|
111 |
+
.pdm.toml
|
112 |
+
|
113 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
114 |
+
__pypackages__/
|
115 |
+
|
116 |
+
# Celery stuff
|
117 |
+
celerybeat-schedule
|
118 |
+
celerybeat.pid
|
119 |
+
|
120 |
+
# SageMath parsed files
|
121 |
+
*.sage.py
|
122 |
+
|
123 |
+
# Environments
|
124 |
+
.env
|
125 |
+
.venv
|
126 |
+
env/
|
127 |
+
venv/
|
128 |
+
ENV/
|
129 |
+
env.bak/
|
130 |
+
venv.bak/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
.idea/
|
162 |
+
|
163 |
+
# Conda temp
|
164 |
+
.condatmp/
|
165 |
+
|
166 |
+
# SGAN specific folders
|
167 |
+
datasets/
|
168 |
+
dlatents/
|
169 |
+
out/
|
170 |
+
training-runs/
|
171 |
+
pretrained/
|
172 |
+
video/
|
173 |
+
_screenshots/
|
stylegan3-fun/__pycache__/legacy.cpython-311.pyc
ADDED
Binary file (29 kB). View file
|
|
stylegan3-fun/__pycache__/legacy.cpython-38.pyc
ADDED
Binary file (15.4 kB). View file
|
|
stylegan3-fun/__pycache__/legacy.cpython-39.pyc
ADDED
Binary file (15.3 kB). View file
|
|
stylegan3-fun/dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
from .util import EasyDict, make_cache_dir_path
|
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (285 Bytes). View file
|
|
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (245 Bytes). View file
|
|
stylegan3-fun/dnnlib/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (245 Bytes). View file
|
|
stylegan3-fun/dnnlib/__pycache__/util.cpython-311.pyc
ADDED
Binary file (26.1 kB). View file
|
|
stylegan3-fun/dnnlib/__pycache__/util.cpython-38.pyc
ADDED
Binary file (14.1 kB). View file
|
|
stylegan3-fun/dnnlib/__pycache__/util.cpython-39.pyc
ADDED
Binary file (14.1 kB). View file
|
|
stylegan3-fun/dnnlib/util.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous utility classes and functions."""
|
10 |
+
|
11 |
+
import ctypes
|
12 |
+
import fnmatch
|
13 |
+
import importlib
|
14 |
+
import inspect
|
15 |
+
import numpy as np
|
16 |
+
import os
|
17 |
+
import shutil
|
18 |
+
import sys
|
19 |
+
import types
|
20 |
+
import io
|
21 |
+
import pickle
|
22 |
+
import re
|
23 |
+
import requests
|
24 |
+
import html
|
25 |
+
import hashlib
|
26 |
+
import glob
|
27 |
+
import tempfile
|
28 |
+
import urllib
|
29 |
+
import urllib.request
|
30 |
+
import uuid
|
31 |
+
|
32 |
+
from distutils.util import strtobool
|
33 |
+
from typing import Any, List, Tuple, Union
|
34 |
+
|
35 |
+
|
36 |
+
# Util classes
|
37 |
+
# ------------------------------------------------------------------------------------------
|
38 |
+
|
39 |
+
|
40 |
+
class EasyDict(dict):
|
41 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
42 |
+
|
43 |
+
def __getattr__(self, name: str) -> Any:
|
44 |
+
try:
|
45 |
+
return self[name]
|
46 |
+
except KeyError:
|
47 |
+
raise AttributeError(name)
|
48 |
+
|
49 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
50 |
+
self[name] = value
|
51 |
+
|
52 |
+
def __delattr__(self, name: str) -> None:
|
53 |
+
del self[name]
|
54 |
+
|
55 |
+
|
56 |
+
class Logger(object):
|
57 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
58 |
+
|
59 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
60 |
+
self.file = None
|
61 |
+
|
62 |
+
if file_name is not None:
|
63 |
+
self.file = open(file_name, file_mode)
|
64 |
+
|
65 |
+
self.should_flush = should_flush
|
66 |
+
self.stdout = sys.stdout
|
67 |
+
self.stderr = sys.stderr
|
68 |
+
|
69 |
+
sys.stdout = self
|
70 |
+
sys.stderr = self
|
71 |
+
|
72 |
+
def __enter__(self) -> "Logger":
|
73 |
+
return self
|
74 |
+
|
75 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
76 |
+
self.close()
|
77 |
+
|
78 |
+
def write(self, text: Union[str, bytes]) -> None:
|
79 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
80 |
+
if isinstance(text, bytes):
|
81 |
+
text = text.decode()
|
82 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
83 |
+
return
|
84 |
+
|
85 |
+
if self.file is not None:
|
86 |
+
self.file.write(text)
|
87 |
+
|
88 |
+
self.stdout.write(text)
|
89 |
+
|
90 |
+
if self.should_flush:
|
91 |
+
self.flush()
|
92 |
+
|
93 |
+
def flush(self) -> None:
|
94 |
+
"""Flush written text to both stdout and a file, if open."""
|
95 |
+
if self.file is not None:
|
96 |
+
self.file.flush()
|
97 |
+
|
98 |
+
self.stdout.flush()
|
99 |
+
|
100 |
+
def close(self) -> None:
|
101 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
102 |
+
self.flush()
|
103 |
+
|
104 |
+
# if using multiple loggers, prevent closing in wrong order
|
105 |
+
if sys.stdout is self:
|
106 |
+
sys.stdout = self.stdout
|
107 |
+
if sys.stderr is self:
|
108 |
+
sys.stderr = self.stderr
|
109 |
+
|
110 |
+
if self.file is not None:
|
111 |
+
self.file.close()
|
112 |
+
self.file = None
|
113 |
+
|
114 |
+
|
115 |
+
# Cache directories
|
116 |
+
# ------------------------------------------------------------------------------------------
|
117 |
+
|
118 |
+
_dnnlib_cache_dir = None
|
119 |
+
|
120 |
+
def set_cache_dir(path: str) -> None:
|
121 |
+
global _dnnlib_cache_dir
|
122 |
+
_dnnlib_cache_dir = path
|
123 |
+
|
124 |
+
def make_cache_dir_path(*paths: str) -> str:
|
125 |
+
if _dnnlib_cache_dir is not None:
|
126 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
127 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
128 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
129 |
+
if 'HOME' in os.environ:
|
130 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
131 |
+
if 'USERPROFILE' in os.environ:
|
132 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
133 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
134 |
+
|
135 |
+
# Small util functions
|
136 |
+
# ------------------------------------------------------------------------------------------
|
137 |
+
|
138 |
+
|
139 |
+
def format_time(seconds: Union[int, float]) -> str:
|
140 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
141 |
+
s = int(np.rint(seconds))
|
142 |
+
|
143 |
+
if s < 60:
|
144 |
+
return "{0}s".format(s)
|
145 |
+
elif s < 60 * 60:
|
146 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
147 |
+
elif s < 24 * 60 * 60:
|
148 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
149 |
+
else:
|
150 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
151 |
+
|
152 |
+
|
153 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
154 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
155 |
+
s = int(np.rint(seconds))
|
156 |
+
|
157 |
+
if s < 60:
|
158 |
+
return "{0}s".format(s)
|
159 |
+
elif s < 60 * 60:
|
160 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
161 |
+
elif s < 24 * 60 * 60:
|
162 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
163 |
+
else:
|
164 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
165 |
+
|
166 |
+
|
167 |
+
def ask_yes_no(question: str) -> bool:
|
168 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
169 |
+
while True:
|
170 |
+
try:
|
171 |
+
print("{0} [y/n]".format(question))
|
172 |
+
return strtobool(input().lower())
|
173 |
+
except ValueError:
|
174 |
+
pass
|
175 |
+
|
176 |
+
|
177 |
+
def tuple_product(t: Tuple) -> Any:
|
178 |
+
"""Calculate the product of the tuple elements."""
|
179 |
+
result = 1
|
180 |
+
|
181 |
+
for v in t:
|
182 |
+
result *= v
|
183 |
+
|
184 |
+
return result
|
185 |
+
|
186 |
+
|
187 |
+
_str_to_ctype = {
|
188 |
+
"uint8": ctypes.c_ubyte,
|
189 |
+
"uint16": ctypes.c_uint16,
|
190 |
+
"uint32": ctypes.c_uint32,
|
191 |
+
"uint64": ctypes.c_uint64,
|
192 |
+
"int8": ctypes.c_byte,
|
193 |
+
"int16": ctypes.c_int16,
|
194 |
+
"int32": ctypes.c_int32,
|
195 |
+
"int64": ctypes.c_int64,
|
196 |
+
"float32": ctypes.c_float,
|
197 |
+
"float64": ctypes.c_double
|
198 |
+
}
|
199 |
+
|
200 |
+
|
201 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
202 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
203 |
+
type_str = None
|
204 |
+
|
205 |
+
if isinstance(type_obj, str):
|
206 |
+
type_str = type_obj
|
207 |
+
elif hasattr(type_obj, "__name__"):
|
208 |
+
type_str = type_obj.__name__
|
209 |
+
elif hasattr(type_obj, "name"):
|
210 |
+
type_str = type_obj.name
|
211 |
+
else:
|
212 |
+
raise RuntimeError("Cannot infer type name from input")
|
213 |
+
|
214 |
+
assert type_str in _str_to_ctype.keys()
|
215 |
+
|
216 |
+
my_dtype = np.dtype(type_str)
|
217 |
+
my_ctype = _str_to_ctype[type_str]
|
218 |
+
|
219 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
220 |
+
|
221 |
+
return my_dtype, my_ctype
|
222 |
+
|
223 |
+
|
224 |
+
def is_pickleable(obj: Any) -> bool:
|
225 |
+
try:
|
226 |
+
with io.BytesIO() as stream:
|
227 |
+
pickle.dump(obj, stream)
|
228 |
+
return True
|
229 |
+
except:
|
230 |
+
return False
|
231 |
+
|
232 |
+
|
233 |
+
# Functionality to import modules/objects by name, and call functions by name
|
234 |
+
# ------------------------------------------------------------------------------------------
|
235 |
+
|
236 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
237 |
+
"""Searches for the underlying module behind the name to some python object.
|
238 |
+
Returns the module and the object name (original name with module part removed)."""
|
239 |
+
|
240 |
+
# allow convenience shorthands, substitute them by full names
|
241 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
242 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
243 |
+
|
244 |
+
# list alternatives for (module_name, local_obj_name)
|
245 |
+
parts = obj_name.split(".")
|
246 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
247 |
+
|
248 |
+
# try each alternative in turn
|
249 |
+
for module_name, local_obj_name in name_pairs:
|
250 |
+
try:
|
251 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
252 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
253 |
+
return module, local_obj_name
|
254 |
+
except:
|
255 |
+
pass
|
256 |
+
|
257 |
+
# maybe some of the modules themselves contain errors?
|
258 |
+
for module_name, _local_obj_name in name_pairs:
|
259 |
+
try:
|
260 |
+
importlib.import_module(module_name) # may raise ImportError
|
261 |
+
except ImportError:
|
262 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
263 |
+
raise
|
264 |
+
|
265 |
+
# maybe the requested attribute is missing?
|
266 |
+
for module_name, local_obj_name in name_pairs:
|
267 |
+
try:
|
268 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
269 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
270 |
+
except ImportError:
|
271 |
+
pass
|
272 |
+
|
273 |
+
# we are out of luck, but we have no idea why
|
274 |
+
raise ImportError(obj_name)
|
275 |
+
|
276 |
+
|
277 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
278 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
279 |
+
if obj_name == '':
|
280 |
+
return module
|
281 |
+
obj = module
|
282 |
+
for part in obj_name.split("."):
|
283 |
+
obj = getattr(obj, part)
|
284 |
+
return obj
|
285 |
+
|
286 |
+
|
287 |
+
def get_obj_by_name(name: str) -> Any:
|
288 |
+
"""Finds the python object with the given name."""
|
289 |
+
module, obj_name = get_module_from_obj_name(name)
|
290 |
+
return get_obj_from_module(module, obj_name)
|
291 |
+
|
292 |
+
|
293 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
294 |
+
"""Finds the python object with the given name and calls it as a function."""
|
295 |
+
assert func_name is not None
|
296 |
+
func_obj = get_obj_by_name(func_name)
|
297 |
+
assert callable(func_obj)
|
298 |
+
return func_obj(*args, **kwargs)
|
299 |
+
|
300 |
+
|
301 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
302 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
303 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
304 |
+
|
305 |
+
|
306 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
307 |
+
"""Get the directory path of the module containing the given object name."""
|
308 |
+
module, _ = get_module_from_obj_name(obj_name)
|
309 |
+
return os.path.dirname(inspect.getfile(module))
|
310 |
+
|
311 |
+
|
312 |
+
def is_top_level_function(obj: Any) -> bool:
|
313 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
314 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
315 |
+
|
316 |
+
|
317 |
+
def get_top_level_function_name(obj: Any) -> str:
|
318 |
+
"""Return the fully-qualified name of a top-level function."""
|
319 |
+
assert is_top_level_function(obj)
|
320 |
+
module = obj.__module__
|
321 |
+
if module == '__main__':
|
322 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
323 |
+
return module + "." + obj.__name__
|
324 |
+
|
325 |
+
|
326 |
+
# File system helpers
|
327 |
+
# ------------------------------------------------------------------------------------------
|
328 |
+
|
329 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
330 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
331 |
+
Returns list of tuples containing both absolute and relative paths."""
|
332 |
+
assert os.path.isdir(dir_path)
|
333 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
334 |
+
|
335 |
+
if ignores is None:
|
336 |
+
ignores = []
|
337 |
+
|
338 |
+
result = []
|
339 |
+
|
340 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
341 |
+
for ignore_ in ignores:
|
342 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
343 |
+
|
344 |
+
# dirs need to be edited in-place
|
345 |
+
for d in dirs_to_remove:
|
346 |
+
dirs.remove(d)
|
347 |
+
|
348 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
349 |
+
|
350 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
351 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
352 |
+
|
353 |
+
if add_base_to_relative:
|
354 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
355 |
+
|
356 |
+
assert len(absolute_paths) == len(relative_paths)
|
357 |
+
result += zip(absolute_paths, relative_paths)
|
358 |
+
|
359 |
+
return result
|
360 |
+
|
361 |
+
|
362 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
363 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
364 |
+
Will create all necessary directories."""
|
365 |
+
for file in files:
|
366 |
+
target_dir_name = os.path.dirname(file[1])
|
367 |
+
|
368 |
+
# will create all intermediate-level directories
|
369 |
+
if not os.path.exists(target_dir_name):
|
370 |
+
os.makedirs(target_dir_name)
|
371 |
+
|
372 |
+
shutil.copyfile(file[0], file[1])
|
373 |
+
|
374 |
+
|
375 |
+
# URL helpers
|
376 |
+
# ------------------------------------------------------------------------------------------
|
377 |
+
|
378 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
379 |
+
"""Determine whether the given object is a valid URL string."""
|
380 |
+
if not isinstance(obj, str) or not "://" in obj:
|
381 |
+
return False
|
382 |
+
if allow_file_urls and obj.startswith('file://'):
|
383 |
+
return True
|
384 |
+
try:
|
385 |
+
res = requests.compat.urlparse(obj)
|
386 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
387 |
+
return False
|
388 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
389 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
390 |
+
return False
|
391 |
+
except:
|
392 |
+
return False
|
393 |
+
return True
|
394 |
+
|
395 |
+
|
396 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
397 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
398 |
+
assert num_attempts >= 1
|
399 |
+
assert not (return_filename and (not cache))
|
400 |
+
|
401 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
402 |
+
if not re.match('^[a-z]+://', url):
|
403 |
+
return url if return_filename else open(url, "rb")
|
404 |
+
|
405 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
406 |
+
# arise on Windows:
|
407 |
+
#
|
408 |
+
# file:///c:/foo.txt
|
409 |
+
#
|
410 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
411 |
+
# invalid. Drop the forward slash for such pathnames.
|
412 |
+
#
|
413 |
+
# If you touch this code path, you should test it on both Linux and
|
414 |
+
# Windows.
|
415 |
+
#
|
416 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
417 |
+
# but that converts forward slashes to backslashes and this causes
|
418 |
+
# its own set of problems.
|
419 |
+
if url.startswith('file://'):
|
420 |
+
filename = urllib.parse.urlparse(url).path
|
421 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
422 |
+
filename = filename[1:]
|
423 |
+
return filename if return_filename else open(filename, "rb")
|
424 |
+
|
425 |
+
assert is_url(url)
|
426 |
+
|
427 |
+
# Lookup from cache.
|
428 |
+
if cache_dir is None:
|
429 |
+
cache_dir = make_cache_dir_path('downloads')
|
430 |
+
|
431 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
432 |
+
if cache:
|
433 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
434 |
+
if len(cache_files) == 1:
|
435 |
+
filename = cache_files[0]
|
436 |
+
return filename if return_filename else open(filename, "rb")
|
437 |
+
|
438 |
+
# Download.
|
439 |
+
url_name = None
|
440 |
+
url_data = None
|
441 |
+
with requests.Session() as session:
|
442 |
+
if verbose:
|
443 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
444 |
+
for attempts_left in reversed(range(num_attempts)):
|
445 |
+
try:
|
446 |
+
with session.get(url) as res:
|
447 |
+
res.raise_for_status()
|
448 |
+
if len(res.content) == 0:
|
449 |
+
raise IOError("No data received")
|
450 |
+
|
451 |
+
if len(res.content) < 8192:
|
452 |
+
content_str = res.content.decode("utf-8")
|
453 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
454 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
455 |
+
if len(links) == 1:
|
456 |
+
url = requests.compat.urljoin(url, links[0])
|
457 |
+
raise IOError("Google Drive virus checker nag")
|
458 |
+
if "Google Drive - Quota exceeded" in content_str:
|
459 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
460 |
+
|
461 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
462 |
+
url_name = match[1] if match else url
|
463 |
+
url_data = res.content
|
464 |
+
if verbose:
|
465 |
+
print(" done")
|
466 |
+
break
|
467 |
+
except KeyboardInterrupt:
|
468 |
+
raise
|
469 |
+
except:
|
470 |
+
if not attempts_left:
|
471 |
+
if verbose:
|
472 |
+
print(" failed")
|
473 |
+
raise
|
474 |
+
if verbose:
|
475 |
+
print(".", end="", flush=True)
|
476 |
+
|
477 |
+
# Save to cache.
|
478 |
+
if cache:
|
479 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
480 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
481 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
482 |
+
os.makedirs(cache_dir, exist_ok=True)
|
483 |
+
with open(temp_file, "wb") as f:
|
484 |
+
f.write(url_data)
|
485 |
+
os.replace(temp_file, cache_file) # atomic
|
486 |
+
if return_filename:
|
487 |
+
return cache_file
|
488 |
+
|
489 |
+
# Return data as file object.
|
490 |
+
assert not return_filename
|
491 |
+
return io.BytesIO(url_data)
|
stylegan3-fun/gui_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (199 Bytes). View file
|
|
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (181 Bytes). View file
|
|
stylegan3-fun/gui_utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (181 Bytes). View file
|
|
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-311.pyc
ADDED
Binary file (26.9 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-38.pyc
ADDED
Binary file (12.1 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/gl_utils.cpython-39.pyc
ADDED
Binary file (12.1 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-311.pyc
ADDED
Binary file (13.2 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-38.pyc
ADDED
Binary file (7.8 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/glfw_window.cpython-39.pyc
ADDED
Binary file (7.78 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-311.pyc
ADDED
Binary file (9.73 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-38.pyc
ADDED
Binary file (5.1 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/imgui_utils.cpython-39.pyc
ADDED
Binary file (5.19 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-311.pyc
ADDED
Binary file (7.11 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-38.pyc
ADDED
Binary file (4 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/imgui_window.cpython-39.pyc
ADDED
Binary file (3.99 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-311.pyc
ADDED
Binary file (9.87 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-38.pyc
ADDED
Binary file (5.19 kB). View file
|
|
stylegan3-fun/gui_utils/__pycache__/text_utils.cpython-39.pyc
ADDED
Binary file (5.15 kB). View file
|
|
stylegan3-fun/gui_utils/gl_utils.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import functools
|
11 |
+
import contextlib
|
12 |
+
import numpy as np
|
13 |
+
import OpenGL.GL as gl
|
14 |
+
import OpenGL.GL.ARB.texture_float
|
15 |
+
import dnnlib
|
16 |
+
|
17 |
+
#----------------------------------------------------------------------------
|
18 |
+
|
19 |
+
def init_egl():
|
20 |
+
assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
|
21 |
+
import OpenGL.EGL as egl
|
22 |
+
import ctypes
|
23 |
+
|
24 |
+
# Initialize EGL.
|
25 |
+
display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
|
26 |
+
assert display != egl.EGL_NO_DISPLAY
|
27 |
+
major = ctypes.c_int32()
|
28 |
+
minor = ctypes.c_int32()
|
29 |
+
ok = egl.eglInitialize(display, major, minor)
|
30 |
+
assert ok
|
31 |
+
assert major.value * 10 + minor.value >= 14
|
32 |
+
|
33 |
+
# Choose config.
|
34 |
+
config_attribs = [
|
35 |
+
egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
|
36 |
+
egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
|
37 |
+
egl.EGL_NONE
|
38 |
+
]
|
39 |
+
configs = (ctypes.c_int32 * 1)()
|
40 |
+
num_configs = ctypes.c_int32()
|
41 |
+
ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
|
42 |
+
assert ok
|
43 |
+
assert num_configs.value == 1
|
44 |
+
config = configs[0]
|
45 |
+
|
46 |
+
# Create dummy pbuffer surface.
|
47 |
+
surface_attribs = [
|
48 |
+
egl.EGL_WIDTH, 1,
|
49 |
+
egl.EGL_HEIGHT, 1,
|
50 |
+
egl.EGL_NONE
|
51 |
+
]
|
52 |
+
surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
|
53 |
+
assert surface != egl.EGL_NO_SURFACE
|
54 |
+
|
55 |
+
# Setup GL context.
|
56 |
+
ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
|
57 |
+
assert ok
|
58 |
+
context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
|
59 |
+
assert context != egl.EGL_NO_CONTEXT
|
60 |
+
ok = egl.eglMakeCurrent(display, surface, surface, context)
|
61 |
+
assert ok
|
62 |
+
|
63 |
+
#----------------------------------------------------------------------------
|
64 |
+
|
65 |
+
_texture_formats = {
|
66 |
+
('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
|
67 |
+
('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
|
68 |
+
('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
|
69 |
+
('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
|
70 |
+
('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
|
71 |
+
('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
|
72 |
+
('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
|
73 |
+
('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
|
74 |
+
}
|
75 |
+
|
76 |
+
def get_texture_format(dtype, channels):
|
77 |
+
return _texture_formats[(np.dtype(dtype).name, int(channels))]
|
78 |
+
|
79 |
+
#----------------------------------------------------------------------------
|
80 |
+
|
81 |
+
def prepare_texture_data(image):
|
82 |
+
image = np.asarray(image)
|
83 |
+
if image.ndim == 2:
|
84 |
+
image = image[:, :, np.newaxis]
|
85 |
+
if image.dtype.name == 'float64':
|
86 |
+
image = image.astype('float32')
|
87 |
+
return image
|
88 |
+
|
89 |
+
#----------------------------------------------------------------------------
|
90 |
+
|
91 |
+
def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
|
92 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
93 |
+
zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
|
94 |
+
align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
|
95 |
+
image = prepare_texture_data(image)
|
96 |
+
height, width, channels = image.shape
|
97 |
+
size = zoom * [width, height]
|
98 |
+
pos = pos - size * align
|
99 |
+
if rint:
|
100 |
+
pos = np.rint(pos)
|
101 |
+
fmt = get_texture_format(image.dtype, channels)
|
102 |
+
|
103 |
+
gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
|
104 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
|
105 |
+
gl.glRasterPos2f(pos[0], pos[1])
|
106 |
+
gl.glPixelZoom(zoom[0], -zoom[1])
|
107 |
+
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
|
108 |
+
gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
|
109 |
+
gl.glPopClientAttrib()
|
110 |
+
gl.glPopAttrib()
|
111 |
+
|
112 |
+
#----------------------------------------------------------------------------
|
113 |
+
|
114 |
+
def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
|
115 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
116 |
+
dtype = np.dtype(dtype)
|
117 |
+
fmt = get_texture_format(dtype, channels)
|
118 |
+
image = np.empty([height, width, channels], dtype=dtype)
|
119 |
+
|
120 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
|
121 |
+
gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
|
122 |
+
gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
|
123 |
+
gl.glPopClientAttrib()
|
124 |
+
return np.flipud(image)
|
125 |
+
|
126 |
+
#----------------------------------------------------------------------------
|
127 |
+
|
128 |
+
class Texture:
|
129 |
+
def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
|
130 |
+
self.gl_id = None
|
131 |
+
self.bilinear = bilinear
|
132 |
+
self.mipmap = mipmap
|
133 |
+
|
134 |
+
# Determine size and dtype.
|
135 |
+
if image is not None:
|
136 |
+
image = prepare_texture_data(image)
|
137 |
+
self.height, self.width, self.channels = image.shape
|
138 |
+
self.dtype = image.dtype
|
139 |
+
else:
|
140 |
+
assert width is not None and height is not None
|
141 |
+
self.width = width
|
142 |
+
self.height = height
|
143 |
+
self.channels = channels if channels is not None else 3
|
144 |
+
self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
|
145 |
+
|
146 |
+
# Validate size and dtype.
|
147 |
+
assert isinstance(self.width, int) and self.width >= 0
|
148 |
+
assert isinstance(self.height, int) and self.height >= 0
|
149 |
+
assert isinstance(self.channels, int) and self.channels >= 1
|
150 |
+
assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
|
151 |
+
|
152 |
+
# Create texture object.
|
153 |
+
self.gl_id = gl.glGenTextures(1)
|
154 |
+
with self.bind():
|
155 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
156 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
157 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
|
158 |
+
gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
|
159 |
+
self.update(image)
|
160 |
+
|
161 |
+
def delete(self):
|
162 |
+
if self.gl_id is not None:
|
163 |
+
gl.glDeleteTextures([self.gl_id])
|
164 |
+
self.gl_id = None
|
165 |
+
|
166 |
+
def __del__(self):
|
167 |
+
try:
|
168 |
+
self.delete()
|
169 |
+
except:
|
170 |
+
pass
|
171 |
+
|
172 |
+
@contextlib.contextmanager
|
173 |
+
def bind(self):
|
174 |
+
prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
|
175 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
|
176 |
+
yield
|
177 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
|
178 |
+
|
179 |
+
def update(self, image):
|
180 |
+
if image is not None:
|
181 |
+
image = prepare_texture_data(image)
|
182 |
+
assert self.is_compatible(image=image)
|
183 |
+
with self.bind():
|
184 |
+
fmt = get_texture_format(self.dtype, self.channels)
|
185 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
|
186 |
+
gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
|
187 |
+
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
|
188 |
+
if self.mipmap:
|
189 |
+
gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
|
190 |
+
gl.glPopClientAttrib()
|
191 |
+
|
192 |
+
def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
|
193 |
+
zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
|
194 |
+
size = zoom * [self.width, self.height]
|
195 |
+
with self.bind():
|
196 |
+
gl.glPushAttrib(gl.GL_ENABLE_BIT)
|
197 |
+
gl.glEnable(gl.GL_TEXTURE_2D)
|
198 |
+
draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
|
199 |
+
gl.glPopAttrib()
|
200 |
+
|
201 |
+
def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
|
202 |
+
if image is not None:
|
203 |
+
if image.ndim != 3:
|
204 |
+
return False
|
205 |
+
ih, iw, ic = image.shape
|
206 |
+
if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
|
207 |
+
return False
|
208 |
+
if width is not None and self.width != width:
|
209 |
+
return False
|
210 |
+
if height is not None and self.height != height:
|
211 |
+
return False
|
212 |
+
if channels is not None and self.channels != channels:
|
213 |
+
return False
|
214 |
+
if dtype is not None and self.dtype != dtype:
|
215 |
+
return False
|
216 |
+
return True
|
217 |
+
|
218 |
+
#----------------------------------------------------------------------------
|
219 |
+
|
220 |
+
class Framebuffer:
|
221 |
+
def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
|
222 |
+
self.texture = texture
|
223 |
+
self.gl_id = None
|
224 |
+
self.gl_color = None
|
225 |
+
self.gl_depth_stencil = None
|
226 |
+
self.msaa = msaa
|
227 |
+
|
228 |
+
# Determine size and dtype.
|
229 |
+
if texture is not None:
|
230 |
+
assert isinstance(self.texture, Texture)
|
231 |
+
self.width = texture.width
|
232 |
+
self.height = texture.height
|
233 |
+
self.channels = texture.channels
|
234 |
+
self.dtype = texture.dtype
|
235 |
+
else:
|
236 |
+
assert width is not None and height is not None
|
237 |
+
self.width = width
|
238 |
+
self.height = height
|
239 |
+
self.channels = channels if channels is not None else 4
|
240 |
+
self.dtype = np.dtype(dtype) if dtype is not None else np.float32
|
241 |
+
|
242 |
+
# Validate size and dtype.
|
243 |
+
assert isinstance(self.width, int) and self.width >= 0
|
244 |
+
assert isinstance(self.height, int) and self.height >= 0
|
245 |
+
assert isinstance(self.channels, int) and self.channels >= 1
|
246 |
+
assert width is None or width == self.width
|
247 |
+
assert height is None or height == self.height
|
248 |
+
assert channels is None or channels == self.channels
|
249 |
+
assert dtype is None or dtype == self.dtype
|
250 |
+
|
251 |
+
# Create framebuffer object.
|
252 |
+
self.gl_id = gl.glGenFramebuffers(1)
|
253 |
+
with self.bind():
|
254 |
+
|
255 |
+
# Setup color buffer.
|
256 |
+
if self.texture is not None:
|
257 |
+
assert self.msaa == 0
|
258 |
+
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
|
259 |
+
else:
|
260 |
+
fmt = get_texture_format(self.dtype, self.channels)
|
261 |
+
self.gl_color = gl.glGenRenderbuffers(1)
|
262 |
+
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
|
263 |
+
gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
|
264 |
+
gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
|
265 |
+
|
266 |
+
# Setup depth/stencil buffer.
|
267 |
+
self.gl_depth_stencil = gl.glGenRenderbuffers(1)
|
268 |
+
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
|
269 |
+
gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
|
270 |
+
gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
|
271 |
+
|
272 |
+
def delete(self):
|
273 |
+
if self.gl_id is not None:
|
274 |
+
gl.glDeleteFramebuffers([self.gl_id])
|
275 |
+
self.gl_id = None
|
276 |
+
if self.gl_color is not None:
|
277 |
+
gl.glDeleteRenderbuffers(1, [self.gl_color])
|
278 |
+
self.gl_color = None
|
279 |
+
if self.gl_depth_stencil is not None:
|
280 |
+
gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
|
281 |
+
self.gl_depth_stencil = None
|
282 |
+
|
283 |
+
def __del__(self):
|
284 |
+
try:
|
285 |
+
self.delete()
|
286 |
+
except:
|
287 |
+
pass
|
288 |
+
|
289 |
+
@contextlib.contextmanager
|
290 |
+
def bind(self):
|
291 |
+
prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
|
292 |
+
prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
|
293 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
|
294 |
+
if self.width is not None and self.height is not None:
|
295 |
+
gl.glViewport(0, 0, self.width, self.height)
|
296 |
+
yield
|
297 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
|
298 |
+
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
|
299 |
+
|
300 |
+
def blit(self, dst=None):
|
301 |
+
assert dst is None or isinstance(dst, Framebuffer)
|
302 |
+
with self.bind():
|
303 |
+
gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo)
|
304 |
+
gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
|
305 |
+
|
306 |
+
#----------------------------------------------------------------------------
|
307 |
+
|
308 |
+
def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
|
309 |
+
assert vertices.ndim == 2 and vertices.shape[1] == 2
|
310 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
311 |
+
size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
|
312 |
+
color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
|
313 |
+
alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
|
314 |
+
|
315 |
+
gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
|
316 |
+
gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
|
317 |
+
gl.glMatrixMode(gl.GL_MODELVIEW)
|
318 |
+
gl.glPushMatrix()
|
319 |
+
|
320 |
+
gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
|
321 |
+
gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
|
322 |
+
gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
|
323 |
+
gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
|
324 |
+
gl.glTranslate(pos[0], pos[1], 0)
|
325 |
+
gl.glScale(size[0], size[1], 1)
|
326 |
+
gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
|
327 |
+
gl.glDrawArrays(mode, 0, vertices.shape[0])
|
328 |
+
|
329 |
+
gl.glPopMatrix()
|
330 |
+
gl.glPopAttrib()
|
331 |
+
gl.glPopClientAttrib()
|
332 |
+
|
333 |
+
#----------------------------------------------------------------------------
|
334 |
+
|
335 |
+
def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
|
336 |
+
assert pos2 is None or size is None
|
337 |
+
pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
|
338 |
+
pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
|
339 |
+
size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
|
340 |
+
size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
|
341 |
+
pos = pos - size * align
|
342 |
+
if rint:
|
343 |
+
pos = np.rint(pos)
|
344 |
+
rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
|
345 |
+
rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
|
346 |
+
if np.min(rounding) == 0:
|
347 |
+
rounding *= 0
|
348 |
+
vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
|
349 |
+
draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
|
350 |
+
|
351 |
+
@functools.lru_cache(maxsize=10000)
|
352 |
+
def _setup_rect(rx, ry):
|
353 |
+
t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
|
354 |
+
s = 1 - np.sin(t); c = 1 - np.cos(t)
|
355 |
+
x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
|
356 |
+
y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
|
357 |
+
v = np.stack([x, y], axis=-1).reshape(-1, 2)
|
358 |
+
return v.astype('float32')
|
359 |
+
|
360 |
+
#----------------------------------------------------------------------------
|
361 |
+
|
362 |
+
def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
|
363 |
+
hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
|
364 |
+
vertices = _setup_circle(float(hole))
|
365 |
+
draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
|
366 |
+
|
367 |
+
@functools.lru_cache(maxsize=10000)
|
368 |
+
def _setup_circle(hole):
|
369 |
+
t = np.linspace(0, np.pi * 2, 128)
|
370 |
+
s = np.sin(t); c = np.cos(t)
|
371 |
+
v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
|
372 |
+
return v.astype('float32')
|
373 |
+
|
374 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/gui_utils/glfw_window.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import time
|
10 |
+
import glfw
|
11 |
+
import OpenGL.GL as gl
|
12 |
+
from . import gl_utils
|
13 |
+
|
14 |
+
#----------------------------------------------------------------------------
|
15 |
+
|
16 |
+
class GlfwWindow: # pylint: disable=too-many-public-methods
|
17 |
+
def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
|
18 |
+
self._glfw_window = None
|
19 |
+
self._drawing_frame = False
|
20 |
+
self._frame_start_time = None
|
21 |
+
self._frame_delta = 0
|
22 |
+
self._fps_limit = None
|
23 |
+
self._vsync = None
|
24 |
+
self._skip_frames = 0
|
25 |
+
self._deferred_show = deferred_show
|
26 |
+
self._close_on_esc = close_on_esc
|
27 |
+
self._esc_pressed = False
|
28 |
+
self._drag_and_drop_paths = None
|
29 |
+
self._capture_next_frame = False
|
30 |
+
self._captured_frame = None
|
31 |
+
|
32 |
+
# Create window.
|
33 |
+
glfw.init()
|
34 |
+
glfw.window_hint(glfw.VISIBLE, False)
|
35 |
+
self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
|
36 |
+
self._attach_glfw_callbacks()
|
37 |
+
self.make_context_current()
|
38 |
+
|
39 |
+
# Adjust window.
|
40 |
+
self.set_vsync(False)
|
41 |
+
self.set_window_size(window_width, window_height)
|
42 |
+
if not self._deferred_show:
|
43 |
+
glfw.show_window(self._glfw_window)
|
44 |
+
|
45 |
+
def close(self):
|
46 |
+
if self._drawing_frame:
|
47 |
+
self.end_frame()
|
48 |
+
if self._glfw_window is not None:
|
49 |
+
glfw.destroy_window(self._glfw_window)
|
50 |
+
self._glfw_window = None
|
51 |
+
#glfw.terminate() # Commented out to play it nice with other glfw clients.
|
52 |
+
|
53 |
+
def __del__(self):
|
54 |
+
try:
|
55 |
+
self.close()
|
56 |
+
except:
|
57 |
+
pass
|
58 |
+
|
59 |
+
@property
|
60 |
+
def window_width(self):
|
61 |
+
return self.content_width
|
62 |
+
|
63 |
+
@property
|
64 |
+
def window_height(self):
|
65 |
+
return self.content_height + self.title_bar_height
|
66 |
+
|
67 |
+
@property
|
68 |
+
def content_width(self):
|
69 |
+
width, _height = glfw.get_window_size(self._glfw_window)
|
70 |
+
return width
|
71 |
+
|
72 |
+
@property
|
73 |
+
def content_height(self):
|
74 |
+
_width, height = glfw.get_window_size(self._glfw_window)
|
75 |
+
return height
|
76 |
+
|
77 |
+
@property
|
78 |
+
def title_bar_height(self):
|
79 |
+
_left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
|
80 |
+
return top
|
81 |
+
|
82 |
+
@property
|
83 |
+
def monitor_width(self):
|
84 |
+
_, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
|
85 |
+
return width
|
86 |
+
|
87 |
+
@property
|
88 |
+
def monitor_height(self):
|
89 |
+
_, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
|
90 |
+
return height
|
91 |
+
|
92 |
+
@property
|
93 |
+
def frame_delta(self):
|
94 |
+
return self._frame_delta
|
95 |
+
|
96 |
+
def set_title(self, title):
|
97 |
+
glfw.set_window_title(self._glfw_window, title)
|
98 |
+
|
99 |
+
def set_window_size(self, width, height):
|
100 |
+
width = min(width, self.monitor_width)
|
101 |
+
height = min(height, self.monitor_height)
|
102 |
+
glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
|
103 |
+
if width == self.monitor_width and height == self.monitor_height:
|
104 |
+
self.maximize()
|
105 |
+
|
106 |
+
def set_content_size(self, width, height):
|
107 |
+
self.set_window_size(width, height + self.title_bar_height)
|
108 |
+
|
109 |
+
def maximize(self):
|
110 |
+
glfw.maximize_window(self._glfw_window)
|
111 |
+
|
112 |
+
def set_position(self, x, y):
|
113 |
+
glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
|
114 |
+
|
115 |
+
def center(self):
|
116 |
+
self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
|
117 |
+
|
118 |
+
def set_vsync(self, vsync):
|
119 |
+
vsync = bool(vsync)
|
120 |
+
if vsync != self._vsync:
|
121 |
+
glfw.swap_interval(1 if vsync else 0)
|
122 |
+
self._vsync = vsync
|
123 |
+
|
124 |
+
def set_fps_limit(self, fps_limit):
|
125 |
+
self._fps_limit = int(fps_limit)
|
126 |
+
|
127 |
+
def should_close(self):
|
128 |
+
return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
|
129 |
+
|
130 |
+
def skip_frame(self):
|
131 |
+
self.skip_frames(1)
|
132 |
+
|
133 |
+
def skip_frames(self, num): # Do not update window for the next N frames.
|
134 |
+
self._skip_frames = max(self._skip_frames, int(num))
|
135 |
+
|
136 |
+
def is_skipping_frames(self):
|
137 |
+
return self._skip_frames > 0
|
138 |
+
|
139 |
+
def capture_next_frame(self):
|
140 |
+
self._capture_next_frame = True
|
141 |
+
|
142 |
+
def pop_captured_frame(self):
|
143 |
+
frame = self._captured_frame
|
144 |
+
self._captured_frame = None
|
145 |
+
return frame
|
146 |
+
|
147 |
+
def pop_drag_and_drop_paths(self):
|
148 |
+
paths = self._drag_and_drop_paths
|
149 |
+
self._drag_and_drop_paths = None
|
150 |
+
return paths
|
151 |
+
|
152 |
+
def draw_frame(self): # To be overridden by subclass.
|
153 |
+
self.begin_frame()
|
154 |
+
# Rendering code goes here.
|
155 |
+
self.end_frame()
|
156 |
+
|
157 |
+
def make_context_current(self):
|
158 |
+
if self._glfw_window is not None:
|
159 |
+
glfw.make_context_current(self._glfw_window)
|
160 |
+
|
161 |
+
def begin_frame(self):
|
162 |
+
# End previous frame.
|
163 |
+
if self._drawing_frame:
|
164 |
+
self.end_frame()
|
165 |
+
|
166 |
+
# Apply FPS limit.
|
167 |
+
if self._frame_start_time is not None and self._fps_limit is not None:
|
168 |
+
delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
|
169 |
+
if delay > 0:
|
170 |
+
time.sleep(delay)
|
171 |
+
cur_time = time.perf_counter()
|
172 |
+
if self._frame_start_time is not None:
|
173 |
+
self._frame_delta = cur_time - self._frame_start_time
|
174 |
+
self._frame_start_time = cur_time
|
175 |
+
|
176 |
+
# Process events.
|
177 |
+
glfw.poll_events()
|
178 |
+
|
179 |
+
# Begin frame.
|
180 |
+
self._drawing_frame = True
|
181 |
+
self.make_context_current()
|
182 |
+
|
183 |
+
# Initialize GL state.
|
184 |
+
gl.glViewport(0, 0, self.content_width, self.content_height)
|
185 |
+
gl.glMatrixMode(gl.GL_PROJECTION)
|
186 |
+
gl.glLoadIdentity()
|
187 |
+
gl.glTranslate(-1, 1, 0)
|
188 |
+
gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
|
189 |
+
gl.glMatrixMode(gl.GL_MODELVIEW)
|
190 |
+
gl.glLoadIdentity()
|
191 |
+
gl.glEnable(gl.GL_BLEND)
|
192 |
+
gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
|
193 |
+
|
194 |
+
# Clear.
|
195 |
+
gl.glClearColor(0, 0, 0, 1)
|
196 |
+
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
|
197 |
+
|
198 |
+
def end_frame(self):
|
199 |
+
assert self._drawing_frame
|
200 |
+
self._drawing_frame = False
|
201 |
+
|
202 |
+
# Skip frames if requested.
|
203 |
+
if self._skip_frames > 0:
|
204 |
+
self._skip_frames -= 1
|
205 |
+
return
|
206 |
+
|
207 |
+
# Capture frame if requested.
|
208 |
+
if self._capture_next_frame:
|
209 |
+
self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
|
210 |
+
self._capture_next_frame = False
|
211 |
+
|
212 |
+
# Update window.
|
213 |
+
if self._deferred_show:
|
214 |
+
glfw.show_window(self._glfw_window)
|
215 |
+
self._deferred_show = False
|
216 |
+
glfw.swap_buffers(self._glfw_window)
|
217 |
+
|
218 |
+
def _attach_glfw_callbacks(self):
|
219 |
+
glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
|
220 |
+
glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
|
221 |
+
|
222 |
+
def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
|
223 |
+
if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
|
224 |
+
self._esc_pressed = True
|
225 |
+
|
226 |
+
def _glfw_drop_callback(self, _window, paths):
|
227 |
+
self._drag_and_drop_paths = paths
|
228 |
+
|
229 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/gui_utils/imgui_utils.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import contextlib
|
10 |
+
import imgui
|
11 |
+
|
12 |
+
#----------------------------------------------------------------------------
|
13 |
+
|
14 |
+
def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
|
15 |
+
s = imgui.get_style()
|
16 |
+
s.window_padding = [spacing, spacing]
|
17 |
+
s.item_spacing = [spacing, spacing]
|
18 |
+
s.item_inner_spacing = [spacing, spacing]
|
19 |
+
s.columns_min_spacing = spacing
|
20 |
+
s.indent_spacing = indent
|
21 |
+
s.scrollbar_size = scrollbar
|
22 |
+
s.frame_padding = [4, 3]
|
23 |
+
s.window_border_size = 1
|
24 |
+
s.child_border_size = 1
|
25 |
+
s.popup_border_size = 1
|
26 |
+
s.frame_border_size = 1
|
27 |
+
s.window_rounding = 0
|
28 |
+
s.child_rounding = 0
|
29 |
+
s.popup_rounding = 3
|
30 |
+
s.frame_rounding = 3
|
31 |
+
s.scrollbar_rounding = 3
|
32 |
+
s.grab_rounding = 3
|
33 |
+
|
34 |
+
getattr(imgui, f'style_colors_{color_scheme}')(s)
|
35 |
+
c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
|
36 |
+
c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
|
37 |
+
s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
|
38 |
+
|
39 |
+
#----------------------------------------------------------------------------
|
40 |
+
|
41 |
+
@contextlib.contextmanager
|
42 |
+
def grayed_out(cond=True):
|
43 |
+
if cond:
|
44 |
+
s = imgui.get_style()
|
45 |
+
text = s.colors[imgui.COLOR_TEXT_DISABLED]
|
46 |
+
grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
|
47 |
+
back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
|
48 |
+
imgui.push_style_color(imgui.COLOR_TEXT, *text)
|
49 |
+
imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
|
50 |
+
imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
|
51 |
+
imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
|
52 |
+
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
|
53 |
+
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
|
54 |
+
imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
|
55 |
+
imgui.push_style_color(imgui.COLOR_BUTTON, *back)
|
56 |
+
imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
|
57 |
+
imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
|
58 |
+
imgui.push_style_color(imgui.COLOR_HEADER, *back)
|
59 |
+
imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
|
60 |
+
imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
|
61 |
+
imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
|
62 |
+
yield
|
63 |
+
imgui.pop_style_color(14)
|
64 |
+
else:
|
65 |
+
yield
|
66 |
+
|
67 |
+
#----------------------------------------------------------------------------
|
68 |
+
|
69 |
+
@contextlib.contextmanager
|
70 |
+
def item_width(width=None):
|
71 |
+
if width is not None:
|
72 |
+
imgui.push_item_width(width)
|
73 |
+
yield
|
74 |
+
imgui.pop_item_width()
|
75 |
+
else:
|
76 |
+
yield
|
77 |
+
|
78 |
+
#----------------------------------------------------------------------------
|
79 |
+
|
80 |
+
def scoped_by_object_id(method):
|
81 |
+
def decorator(self, *args, **kwargs):
|
82 |
+
imgui.push_id(str(id(self)))
|
83 |
+
res = method(self, *args, **kwargs)
|
84 |
+
imgui.pop_id()
|
85 |
+
return res
|
86 |
+
return decorator
|
87 |
+
|
88 |
+
#----------------------------------------------------------------------------
|
89 |
+
|
90 |
+
def button(label, width=0, enabled=True):
|
91 |
+
with grayed_out(not enabled):
|
92 |
+
clicked = imgui.button(label, width=width)
|
93 |
+
clicked = clicked and enabled
|
94 |
+
return clicked
|
95 |
+
|
96 |
+
#----------------------------------------------------------------------------
|
97 |
+
|
98 |
+
def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
|
99 |
+
expanded = False
|
100 |
+
if show:
|
101 |
+
if default:
|
102 |
+
flags |= imgui.TREE_NODE_DEFAULT_OPEN
|
103 |
+
if not enabled:
|
104 |
+
flags |= imgui.TREE_NODE_LEAF
|
105 |
+
with grayed_out(not enabled):
|
106 |
+
expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
|
107 |
+
expanded = expanded and enabled
|
108 |
+
return expanded, visible
|
109 |
+
|
110 |
+
#----------------------------------------------------------------------------
|
111 |
+
|
112 |
+
def popup_button(label, width=0, enabled=True):
|
113 |
+
if button(label, width, enabled):
|
114 |
+
imgui.open_popup(label)
|
115 |
+
opened = imgui.begin_popup(label)
|
116 |
+
return opened
|
117 |
+
|
118 |
+
#----------------------------------------------------------------------------
|
119 |
+
|
120 |
+
def input_text(label, value, buffer_length, flags, width=None, help_text=''):
|
121 |
+
old_value = value
|
122 |
+
color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
|
123 |
+
if value == '':
|
124 |
+
color[-1] *= 0.5
|
125 |
+
with item_width(width):
|
126 |
+
imgui.push_style_color(imgui.COLOR_TEXT, *color)
|
127 |
+
value = value if value != '' else help_text
|
128 |
+
changed, value = imgui.input_text(label, value, buffer_length, flags)
|
129 |
+
value = value if value != help_text else ''
|
130 |
+
imgui.pop_style_color(1)
|
131 |
+
if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
|
132 |
+
changed = (value != old_value)
|
133 |
+
return changed, value
|
134 |
+
|
135 |
+
#----------------------------------------------------------------------------
|
136 |
+
|
137 |
+
def drag_previous_control(enabled=True):
|
138 |
+
dragging = False
|
139 |
+
dx = 0
|
140 |
+
dy = 0
|
141 |
+
if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
|
142 |
+
if enabled:
|
143 |
+
dragging = True
|
144 |
+
dx, dy = imgui.get_mouse_drag_delta()
|
145 |
+
imgui.reset_mouse_drag_delta()
|
146 |
+
imgui.end_drag_drop_source()
|
147 |
+
return dragging, dx, dy
|
148 |
+
|
149 |
+
#----------------------------------------------------------------------------
|
150 |
+
|
151 |
+
def drag_button(label, width=0, enabled=True):
|
152 |
+
clicked = button(label, width=width, enabled=enabled)
|
153 |
+
dragging, dx, dy = drag_previous_control(enabled=enabled)
|
154 |
+
return clicked, dragging, dx, dy
|
155 |
+
|
156 |
+
#----------------------------------------------------------------------------
|
157 |
+
|
158 |
+
def drag_hidden_window(label, x, y, width, height, enabled=True):
|
159 |
+
imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
|
160 |
+
imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
|
161 |
+
imgui.set_next_window_position(x, y)
|
162 |
+
imgui.set_next_window_size(width, height)
|
163 |
+
imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
|
164 |
+
dragging, dx, dy = drag_previous_control(enabled=enabled)
|
165 |
+
imgui.end()
|
166 |
+
imgui.pop_style_color(2)
|
167 |
+
return dragging, dx, dy
|
168 |
+
|
169 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/gui_utils/imgui_window.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import os
|
10 |
+
import imgui
|
11 |
+
import imgui.integrations.glfw
|
12 |
+
|
13 |
+
from . import glfw_window
|
14 |
+
from . import imgui_utils
|
15 |
+
from . import text_utils
|
16 |
+
|
17 |
+
#----------------------------------------------------------------------------
|
18 |
+
|
19 |
+
class ImguiWindow(glfw_window.GlfwWindow):
|
20 |
+
def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
|
21 |
+
if font is None:
|
22 |
+
font = text_utils.get_default_font()
|
23 |
+
font_sizes = {int(size) for size in font_sizes}
|
24 |
+
super().__init__(title=title, **glfw_kwargs)
|
25 |
+
|
26 |
+
# Init fields.
|
27 |
+
self._imgui_context = None
|
28 |
+
self._imgui_renderer = None
|
29 |
+
self._imgui_fonts = None
|
30 |
+
self._cur_font_size = max(font_sizes)
|
31 |
+
|
32 |
+
# Delete leftover imgui.ini to avoid unexpected behavior.
|
33 |
+
if os.path.isfile('imgui.ini'):
|
34 |
+
os.remove('imgui.ini')
|
35 |
+
|
36 |
+
# Init ImGui.
|
37 |
+
self._imgui_context = imgui.create_context()
|
38 |
+
self._imgui_renderer = _GlfwRenderer(self._glfw_window)
|
39 |
+
self._attach_glfw_callbacks()
|
40 |
+
imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
|
41 |
+
imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
|
42 |
+
self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
|
43 |
+
self._imgui_renderer.refresh_font_texture()
|
44 |
+
|
45 |
+
def close(self):
|
46 |
+
self.make_context_current()
|
47 |
+
self._imgui_fonts = None
|
48 |
+
if self._imgui_renderer is not None:
|
49 |
+
self._imgui_renderer.shutdown()
|
50 |
+
self._imgui_renderer = None
|
51 |
+
if self._imgui_context is not None:
|
52 |
+
#imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
|
53 |
+
self._imgui_context = None
|
54 |
+
super().close()
|
55 |
+
|
56 |
+
def _glfw_key_callback(self, *args):
|
57 |
+
super()._glfw_key_callback(*args)
|
58 |
+
self._imgui_renderer.keyboard_callback(*args)
|
59 |
+
|
60 |
+
@property
|
61 |
+
def font_size(self):
|
62 |
+
return self._cur_font_size
|
63 |
+
|
64 |
+
@property
|
65 |
+
def spacing(self):
|
66 |
+
return round(self._cur_font_size * 0.4)
|
67 |
+
|
68 |
+
def set_font_size(self, target): # Applied on next frame.
|
69 |
+
self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
|
70 |
+
|
71 |
+
def begin_frame(self):
|
72 |
+
# Begin glfw frame.
|
73 |
+
super().begin_frame()
|
74 |
+
|
75 |
+
# Process imgui events.
|
76 |
+
self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
|
77 |
+
if self.content_width > 0 and self.content_height > 0:
|
78 |
+
self._imgui_renderer.process_inputs()
|
79 |
+
|
80 |
+
# Begin imgui frame.
|
81 |
+
imgui.new_frame()
|
82 |
+
imgui.push_font(self._imgui_fonts[self._cur_font_size])
|
83 |
+
imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
|
84 |
+
|
85 |
+
def end_frame(self):
|
86 |
+
imgui.pop_font()
|
87 |
+
imgui.render()
|
88 |
+
imgui.end_frame()
|
89 |
+
self._imgui_renderer.render(imgui.get_draw_data())
|
90 |
+
super().end_frame()
|
91 |
+
|
92 |
+
#----------------------------------------------------------------------------
|
93 |
+
# Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
|
94 |
+
|
95 |
+
class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
|
96 |
+
def __init__(self, *args, **kwargs):
|
97 |
+
super().__init__(*args, **kwargs)
|
98 |
+
self.mouse_wheel_multiplier = 1
|
99 |
+
|
100 |
+
def scroll_callback(self, window, x_offset, y_offset):
|
101 |
+
self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
|
102 |
+
|
103 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/gui_utils/text_utils.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
import functools
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import dnnlib
|
13 |
+
import numpy as np
|
14 |
+
import PIL.Image
|
15 |
+
import PIL.ImageFont
|
16 |
+
import scipy.ndimage
|
17 |
+
|
18 |
+
from . import gl_utils
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
def get_default_font():
|
23 |
+
url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
|
24 |
+
return dnnlib.util.open_url(url, return_filename=True)
|
25 |
+
|
26 |
+
#----------------------------------------------------------------------------
|
27 |
+
|
28 |
+
@functools.lru_cache(maxsize=None)
|
29 |
+
def get_pil_font(font=None, size=32):
|
30 |
+
if font is None:
|
31 |
+
font = get_default_font()
|
32 |
+
return PIL.ImageFont.truetype(font=font, size=size)
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
def get_array(string, *, dropshadow_radius: int=None, **kwargs):
|
37 |
+
if dropshadow_radius is not None:
|
38 |
+
offset_x = int(np.ceil(dropshadow_radius*2/3))
|
39 |
+
offset_y = int(np.ceil(dropshadow_radius*2/3))
|
40 |
+
return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
|
41 |
+
else:
|
42 |
+
return _get_array_priv(string, **kwargs)
|
43 |
+
|
44 |
+
@functools.lru_cache(maxsize=10000)
|
45 |
+
def _get_array_priv(
|
46 |
+
string: str, *,
|
47 |
+
size: int = 32,
|
48 |
+
max_width: Optional[int]=None,
|
49 |
+
max_height: Optional[int]=None,
|
50 |
+
min_size=10,
|
51 |
+
shrink_coef=0.8,
|
52 |
+
dropshadow_radius: int=None,
|
53 |
+
offset_x: int=None,
|
54 |
+
offset_y: int=None,
|
55 |
+
**kwargs
|
56 |
+
):
|
57 |
+
cur_size = size
|
58 |
+
array = None
|
59 |
+
while True:
|
60 |
+
if dropshadow_radius is not None:
|
61 |
+
# separate implementation for dropshadow text rendering
|
62 |
+
array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
|
63 |
+
else:
|
64 |
+
array = _get_array_impl(string, size=cur_size, **kwargs)
|
65 |
+
height, width, _ = array.shape
|
66 |
+
if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
|
67 |
+
break
|
68 |
+
cur_size = max(int(cur_size * shrink_coef), min_size)
|
69 |
+
return array
|
70 |
+
|
71 |
+
#----------------------------------------------------------------------------
|
72 |
+
|
73 |
+
@functools.lru_cache(maxsize=10000)
|
74 |
+
def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
|
75 |
+
pil_font = get_pil_font(font=font, size=size)
|
76 |
+
lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
|
77 |
+
lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
|
78 |
+
width = max(line.shape[1] for line in lines)
|
79 |
+
lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
|
80 |
+
line_spacing = line_pad if line_pad is not None else size // 2
|
81 |
+
lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
|
82 |
+
mask = np.concatenate(lines, axis=0)
|
83 |
+
alpha = mask
|
84 |
+
if outline > 0:
|
85 |
+
mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
|
86 |
+
alpha = mask.astype(np.float32) / 255
|
87 |
+
alpha = scipy.ndimage.gaussian_filter(alpha, outline)
|
88 |
+
alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
|
89 |
+
alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
|
90 |
+
alpha = np.maximum(alpha, mask)
|
91 |
+
return np.stack([mask, alpha], axis=-1)
|
92 |
+
|
93 |
+
#----------------------------------------------------------------------------
|
94 |
+
|
95 |
+
@functools.lru_cache(maxsize=10000)
|
96 |
+
def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
|
97 |
+
assert (offset_x > 0) and (offset_y > 0)
|
98 |
+
pil_font = get_pil_font(font=font, size=size)
|
99 |
+
lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
|
100 |
+
lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
|
101 |
+
width = max(line.shape[1] for line in lines)
|
102 |
+
lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
|
103 |
+
line_spacing = line_pad if line_pad is not None else size // 2
|
104 |
+
lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
|
105 |
+
mask = np.concatenate(lines, axis=0)
|
106 |
+
alpha = mask
|
107 |
+
|
108 |
+
mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
|
109 |
+
alpha = mask.astype(np.float32) / 255
|
110 |
+
alpha = scipy.ndimage.gaussian_filter(alpha, radius)
|
111 |
+
alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
|
112 |
+
alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
|
113 |
+
alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
|
114 |
+
alpha = np.maximum(alpha, mask)
|
115 |
+
return np.stack([mask, alpha], axis=-1)
|
116 |
+
|
117 |
+
#----------------------------------------------------------------------------
|
118 |
+
|
119 |
+
@functools.lru_cache(maxsize=10000)
|
120 |
+
def get_texture(string, bilinear=True, mipmap=True, **kwargs):
|
121 |
+
return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
|
122 |
+
|
123 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/legacy.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Converting legacy network pickle into the new format."""
|
10 |
+
|
11 |
+
import click
|
12 |
+
import pickle
|
13 |
+
import re
|
14 |
+
import copy
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import dnnlib
|
18 |
+
from torch_utils import misc
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
def load_network_pkl(f, force_fp16=False):
|
23 |
+
data = _LegacyUnpickler(f).load()
|
24 |
+
try:
|
25 |
+
# Legacy TensorFlow pickle => convert.
|
26 |
+
if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
|
27 |
+
tf_G, tf_D, tf_Gs = data
|
28 |
+
G = convert_tf_generator(tf_G)
|
29 |
+
D = convert_tf_discriminator(tf_D)
|
30 |
+
G_ema = convert_tf_generator(tf_Gs)
|
31 |
+
data = dict(G=G, D=D, G_ema=G_ema)
|
32 |
+
|
33 |
+
# Add missing fields.
|
34 |
+
if 'training_set_kwargs' not in data:
|
35 |
+
data['training_set_kwargs'] = None
|
36 |
+
if 'augment_pipe' not in data:
|
37 |
+
data['augment_pipe'] = None
|
38 |
+
|
39 |
+
# Validate contents.
|
40 |
+
if 'G' in data:
|
41 |
+
assert isinstance(data['G'], torch.nn.Module)
|
42 |
+
if 'D' in data:
|
43 |
+
assert isinstance(data['D'], torch.nn.Module)
|
44 |
+
assert isinstance(data['G_ema'], torch.nn.Module)
|
45 |
+
assert isinstance(data['training_set_kwargs'], (dict, type(None)))
|
46 |
+
assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
|
47 |
+
|
48 |
+
# Force FP16.
|
49 |
+
if force_fp16:
|
50 |
+
for key in ['G', 'D', 'G_ema']:
|
51 |
+
old = data[key]
|
52 |
+
kwargs = copy.deepcopy(old.init_kwargs)
|
53 |
+
fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
|
54 |
+
fp16_kwargs.num_fp16_res = 4
|
55 |
+
fp16_kwargs.conv_clamp = 256
|
56 |
+
if kwargs != old.init_kwargs:
|
57 |
+
new = type(old)(**kwargs).eval().requires_grad_(False)
|
58 |
+
misc.copy_params_and_buffers(old, new, require_all=True)
|
59 |
+
data[key] = new
|
60 |
+
except KeyError:
|
61 |
+
# Most likely a StyleGAN-NADA pkl, so pass and return data
|
62 |
+
pass
|
63 |
+
return data
|
64 |
+
|
65 |
+
#----------------------------------------------------------------------------
|
66 |
+
|
67 |
+
class _TFNetworkStub(dnnlib.EasyDict):
|
68 |
+
pass
|
69 |
+
|
70 |
+
class _LegacyUnpickler(pickle.Unpickler):
|
71 |
+
def find_class(self, module, name):
|
72 |
+
if module == 'dnnlib.tflib.network' and name == 'Network':
|
73 |
+
return _TFNetworkStub
|
74 |
+
return super().find_class(module, name)
|
75 |
+
|
76 |
+
#----------------------------------------------------------------------------
|
77 |
+
|
78 |
+
def _collect_tf_params(tf_net):
|
79 |
+
# pylint: disable=protected-access
|
80 |
+
tf_params = dict()
|
81 |
+
def recurse(prefix, tf_net):
|
82 |
+
for name, value in tf_net.variables:
|
83 |
+
tf_params[prefix + name] = value
|
84 |
+
for name, comp in tf_net.components.items():
|
85 |
+
recurse(prefix + name + '/', comp)
|
86 |
+
recurse('', tf_net)
|
87 |
+
return tf_params
|
88 |
+
|
89 |
+
#----------------------------------------------------------------------------
|
90 |
+
|
91 |
+
def _populate_module_params(module, *patterns):
|
92 |
+
for name, tensor in misc.named_params_and_buffers(module):
|
93 |
+
found = False
|
94 |
+
value = None
|
95 |
+
for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
|
96 |
+
match = re.fullmatch(pattern, name)
|
97 |
+
if match:
|
98 |
+
found = True
|
99 |
+
if value_fn is not None:
|
100 |
+
value = value_fn(*match.groups())
|
101 |
+
break
|
102 |
+
try:
|
103 |
+
assert found
|
104 |
+
if value is not None:
|
105 |
+
tensor.copy_(torch.from_numpy(np.array(value)))
|
106 |
+
except:
|
107 |
+
print(name, list(tensor.shape))
|
108 |
+
raise
|
109 |
+
|
110 |
+
#----------------------------------------------------------------------------
|
111 |
+
|
112 |
+
def convert_tf_generator(tf_G):
|
113 |
+
if tf_G.version < 4:
|
114 |
+
raise ValueError('TensorFlow pickle version too low')
|
115 |
+
|
116 |
+
# Collect kwargs.
|
117 |
+
tf_kwargs = tf_G.static_kwargs
|
118 |
+
known_kwargs = set()
|
119 |
+
def kwarg(tf_name, default=None, none=None):
|
120 |
+
known_kwargs.add(tf_name)
|
121 |
+
val = tf_kwargs.get(tf_name, default)
|
122 |
+
return val if val is not None else none
|
123 |
+
|
124 |
+
# Convert kwargs.
|
125 |
+
from training import networks_stylegan2
|
126 |
+
network_class = networks_stylegan2.Generator
|
127 |
+
kwargs = dnnlib.EasyDict(
|
128 |
+
z_dim = kwarg('latent_size', 512),
|
129 |
+
c_dim = kwarg('label_size', 0),
|
130 |
+
w_dim = kwarg('dlatent_size', 512),
|
131 |
+
img_resolution = kwarg('resolution', 1024),
|
132 |
+
img_channels = kwarg('num_channels', 3),
|
133 |
+
channel_base = kwarg('fmap_base', 16384) * 2,
|
134 |
+
channel_max = kwarg('fmap_max', 512),
|
135 |
+
num_fp16_res = kwarg('num_fp16_res', 0),
|
136 |
+
conv_clamp = kwarg('conv_clamp', None),
|
137 |
+
architecture = kwarg('architecture', 'skip'),
|
138 |
+
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
|
139 |
+
use_noise = kwarg('use_noise', True),
|
140 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
141 |
+
mapping_kwargs = dnnlib.EasyDict(
|
142 |
+
num_layers = kwarg('mapping_layers', 8),
|
143 |
+
embed_features = kwarg('label_fmaps', None),
|
144 |
+
layer_features = kwarg('mapping_fmaps', None),
|
145 |
+
activation = kwarg('mapping_nonlinearity', 'lrelu'),
|
146 |
+
lr_multiplier = kwarg('mapping_lrmul', 0.01),
|
147 |
+
w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
|
148 |
+
),
|
149 |
+
)
|
150 |
+
# Check for unknown kwargs.
|
151 |
+
kwarg('truncation_psi')
|
152 |
+
kwarg('truncation_cutoff')
|
153 |
+
kwarg('style_mixing_prob')
|
154 |
+
kwarg('structure')
|
155 |
+
kwarg('conditioning')
|
156 |
+
kwarg('fused_modconv')
|
157 |
+
kwarg('randomize_noise') # PR #173 by @cobanov, for StyleGAN2 models/TF 1.15
|
158 |
+
kwarg('resolution_h') # For transferring from --network=anime1024 --cfg=stylegan2-ext
|
159 |
+
kwarg('resolution_w') # For transferring from --network=anime1024 --cfg=stylegan2-ext
|
160 |
+
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
|
161 |
+
if len(unknown_kwargs) > 0:
|
162 |
+
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
|
163 |
+
|
164 |
+
# Collect params.
|
165 |
+
tf_params = _collect_tf_params(tf_G)
|
166 |
+
|
167 |
+
# Add changes from --cfg=stylegan2-ext; currently can't think of another way to do this
|
168 |
+
if 'resolution_h' and 'resolution_w' in tf_kwargs:
|
169 |
+
kwargs.channel_max = 1024
|
170 |
+
kwargs.w_dim = 1024
|
171 |
+
kwargs.z_dim = 1024
|
172 |
+
kwargs.mapping_kwargs.layer_features = 1024
|
173 |
+
kwargs.mapping_kwargs.num_layers = 4
|
174 |
+
kwargs.channel_base *= 2
|
175 |
+
|
176 |
+
for name, value in list(tf_params.items()):
|
177 |
+
match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
|
178 |
+
if match:
|
179 |
+
r = kwargs.img_resolution // (2 ** int(match.group(1)))
|
180 |
+
tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
|
181 |
+
kwargs.synthesis.kwargs.architecture = 'orig'
|
182 |
+
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
|
183 |
+
|
184 |
+
# Convert params.
|
185 |
+
G = network_class(**kwargs).eval().requires_grad_(False)
|
186 |
+
# pylint: disable=unnecessary-lambda
|
187 |
+
# pylint: disable=f-string-without-interpolation
|
188 |
+
_populate_module_params(G,
|
189 |
+
r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
|
190 |
+
r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
|
191 |
+
r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
|
192 |
+
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
|
193 |
+
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
|
194 |
+
r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
|
195 |
+
r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
|
196 |
+
r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
|
197 |
+
r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
|
198 |
+
r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
|
199 |
+
r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
|
200 |
+
r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
|
201 |
+
r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
|
202 |
+
r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
|
203 |
+
r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
|
204 |
+
r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
|
205 |
+
r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
|
206 |
+
r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
|
207 |
+
r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
|
208 |
+
r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
|
209 |
+
r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
|
210 |
+
r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
|
211 |
+
r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
|
212 |
+
r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
|
213 |
+
r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
|
214 |
+
r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
|
215 |
+
r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
|
216 |
+
r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
|
217 |
+
r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
|
218 |
+
r'.*\.resample_filter', None,
|
219 |
+
r'.*\.act_filter', None,
|
220 |
+
)
|
221 |
+
return G
|
222 |
+
|
223 |
+
#----------------------------------------------------------------------------
|
224 |
+
|
225 |
+
def convert_tf_discriminator(tf_D):
|
226 |
+
if tf_D.version < 4:
|
227 |
+
raise ValueError('TensorFlow pickle version too low')
|
228 |
+
|
229 |
+
# Collect kwargs.
|
230 |
+
tf_kwargs = tf_D.static_kwargs
|
231 |
+
known_kwargs = set()
|
232 |
+
def kwarg(tf_name, default=None):
|
233 |
+
known_kwargs.add(tf_name)
|
234 |
+
return tf_kwargs.get(tf_name, default)
|
235 |
+
|
236 |
+
# Convert kwargs.
|
237 |
+
kwargs = dnnlib.EasyDict(
|
238 |
+
c_dim = kwarg('label_size', 0),
|
239 |
+
img_resolution = kwarg('resolution', 1024),
|
240 |
+
img_channels = kwarg('num_channels', 3),
|
241 |
+
architecture = kwarg('architecture', 'resnet'),
|
242 |
+
channel_base = kwarg('fmap_base', 16384) * 2,
|
243 |
+
channel_max = kwarg('fmap_max', 512),
|
244 |
+
num_fp16_res = kwarg('num_fp16_res', 0),
|
245 |
+
conv_clamp = kwarg('conv_clamp', None),
|
246 |
+
cmap_dim = kwarg('mapping_fmaps', None),
|
247 |
+
block_kwargs = dnnlib.EasyDict(
|
248 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
249 |
+
resample_filter = kwarg('resample_kernel', [1,3,3,1]),
|
250 |
+
freeze_layers = kwarg('freeze_layers', 0),
|
251 |
+
),
|
252 |
+
mapping_kwargs = dnnlib.EasyDict(
|
253 |
+
num_layers = kwarg('mapping_layers', 0),
|
254 |
+
embed_features = kwarg('mapping_fmaps', None),
|
255 |
+
layer_features = kwarg('mapping_fmaps', None),
|
256 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
257 |
+
lr_multiplier = kwarg('mapping_lrmul', 0.1),
|
258 |
+
),
|
259 |
+
epilogue_kwargs = dnnlib.EasyDict(
|
260 |
+
mbstd_group_size = kwarg('mbstd_group_size', None),
|
261 |
+
mbstd_num_channels = kwarg('mbstd_num_features', 1),
|
262 |
+
activation = kwarg('nonlinearity', 'lrelu'),
|
263 |
+
),
|
264 |
+
)
|
265 |
+
|
266 |
+
# Check for unknown kwargs.
|
267 |
+
kwarg('structure')
|
268 |
+
kwarg('conditioning')
|
269 |
+
kwarg('resolution_h') # For transferring from --network=anime1024 --cfg=stylegan2-ext
|
270 |
+
kwarg('resolution_w') # For transferring from --network=anime1024 --cfg=stylegan2-ext
|
271 |
+
unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
|
272 |
+
if len(unknown_kwargs) > 0:
|
273 |
+
raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
|
274 |
+
|
275 |
+
# Add changes from --cfg=stylegan2-ext. (I really can't think of another way tbh)
|
276 |
+
if 'resolution_h' and 'resolution_w' in tf_kwargs:
|
277 |
+
kwargs.epilogue_kwargs.mbstd_num_channels = 4
|
278 |
+
|
279 |
+
# Collect params.
|
280 |
+
tf_params = _collect_tf_params(tf_D)
|
281 |
+
for name, value in list(tf_params.items()):
|
282 |
+
match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
|
283 |
+
if match:
|
284 |
+
r = kwargs.img_resolution // (2 ** int(match.group(1)))
|
285 |
+
tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
|
286 |
+
kwargs.architecture = 'orig'
|
287 |
+
#for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
|
288 |
+
|
289 |
+
# Convert params.
|
290 |
+
from training import networks_stylegan2
|
291 |
+
D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
|
292 |
+
# pylint: disable=unnecessary-lambda
|
293 |
+
# pylint: disable=f-string-without-interpolation
|
294 |
+
_populate_module_params(D,
|
295 |
+
r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
|
296 |
+
r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
|
297 |
+
r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
|
298 |
+
r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
|
299 |
+
r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
|
300 |
+
r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
|
301 |
+
r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
|
302 |
+
r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
|
303 |
+
r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
|
304 |
+
r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
|
305 |
+
r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
|
306 |
+
r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
|
307 |
+
r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
|
308 |
+
r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
|
309 |
+
r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
|
310 |
+
r'.*\.resample_filter', None,
|
311 |
+
)
|
312 |
+
return D
|
313 |
+
|
314 |
+
#----------------------------------------------------------------------------
|
315 |
+
|
316 |
+
@click.command()
|
317 |
+
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
|
318 |
+
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
|
319 |
+
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
|
320 |
+
def convert_network_pickle(source, dest, force_fp16):
|
321 |
+
"""Convert legacy network pickle into the native PyTorch format.
|
322 |
+
|
323 |
+
The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
|
324 |
+
It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
|
325 |
+
|
326 |
+
Example:
|
327 |
+
|
328 |
+
\b
|
329 |
+
python legacy.py \\
|
330 |
+
--source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
|
331 |
+
--dest=stylegan2-cat-config-f.pkl
|
332 |
+
"""
|
333 |
+
print(f'Loading "{source}"...')
|
334 |
+
with dnnlib.util.open_url(source) as f:
|
335 |
+
data = load_network_pkl(f, force_fp16=force_fp16)
|
336 |
+
print(f'Saving "{dest}"...')
|
337 |
+
with open(dest, 'wb') as f:
|
338 |
+
pickle.dump(data, f)
|
339 |
+
print('Done.')
|
340 |
+
|
341 |
+
#----------------------------------------------------------------------------
|
342 |
+
|
343 |
+
if __name__ == "__main__":
|
344 |
+
convert_network_pickle() # pylint: disable=no-value-for-parameter
|
345 |
+
|
346 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
# empty
|
stylegan3-fun/metrics/equivariance.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
|
10 |
+
"Alias-Free Generative Adversarial Networks"."""
|
11 |
+
|
12 |
+
import copy
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.fft
|
16 |
+
from torch_utils.ops import upfirdn2d
|
17 |
+
from . import metric_utils
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
# Utilities.
|
21 |
+
|
22 |
+
def sinc(x):
|
23 |
+
y = (x * np.pi).abs()
|
24 |
+
z = torch.sin(y) / y.clamp(1e-30, float('inf'))
|
25 |
+
return torch.where(y < 1e-30, torch.ones_like(x), z)
|
26 |
+
|
27 |
+
def lanczos_window(x, a):
|
28 |
+
x = x.abs() / a
|
29 |
+
return torch.where(x < 1, sinc(x), torch.zeros_like(x))
|
30 |
+
|
31 |
+
def rotation_matrix(angle):
|
32 |
+
angle = torch.as_tensor(angle).to(torch.float32)
|
33 |
+
mat = torch.eye(3, device=angle.device)
|
34 |
+
mat[0, 0] = angle.cos()
|
35 |
+
mat[0, 1] = angle.sin()
|
36 |
+
mat[1, 0] = -angle.sin()
|
37 |
+
mat[1, 1] = angle.cos()
|
38 |
+
return mat
|
39 |
+
|
40 |
+
#----------------------------------------------------------------------------
|
41 |
+
# Apply integer translation to a batch of 2D images. Corresponds to the
|
42 |
+
# operator T_x in Appendix E.1.
|
43 |
+
|
44 |
+
def apply_integer_translation(x, tx, ty):
|
45 |
+
_N, _C, H, W = x.shape
|
46 |
+
tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
|
47 |
+
ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
|
48 |
+
ix = tx.round().to(torch.int64)
|
49 |
+
iy = ty.round().to(torch.int64)
|
50 |
+
|
51 |
+
z = torch.zeros_like(x)
|
52 |
+
m = torch.zeros_like(x)
|
53 |
+
if abs(ix) < W and abs(iy) < H:
|
54 |
+
y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
|
55 |
+
z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
|
56 |
+
m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
|
57 |
+
return z, m
|
58 |
+
|
59 |
+
#----------------------------------------------------------------------------
|
60 |
+
# Apply integer translation to a batch of 2D images. Corresponds to the
|
61 |
+
# operator T_x in Appendix E.2.
|
62 |
+
|
63 |
+
def apply_fractional_translation(x, tx, ty, a=3):
|
64 |
+
_N, _C, H, W = x.shape
|
65 |
+
tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
|
66 |
+
ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
|
67 |
+
ix = tx.floor().to(torch.int64)
|
68 |
+
iy = ty.floor().to(torch.int64)
|
69 |
+
fx = tx - ix
|
70 |
+
fy = ty - iy
|
71 |
+
b = a - 1
|
72 |
+
|
73 |
+
z = torch.zeros_like(x)
|
74 |
+
zx0 = max(ix - b, 0)
|
75 |
+
zy0 = max(iy - b, 0)
|
76 |
+
zx1 = min(ix + a, 0) + W
|
77 |
+
zy1 = min(iy + a, 0) + H
|
78 |
+
if zx0 < zx1 and zy0 < zy1:
|
79 |
+
taps = torch.arange(a * 2, device=x.device) - b
|
80 |
+
filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
|
81 |
+
filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
|
82 |
+
y = x
|
83 |
+
y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
|
84 |
+
y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
|
85 |
+
y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
|
86 |
+
z[:, :, zy0:zy1, zx0:zx1] = y
|
87 |
+
|
88 |
+
m = torch.zeros_like(x)
|
89 |
+
mx0 = max(ix + a, 0)
|
90 |
+
my0 = max(iy + a, 0)
|
91 |
+
mx1 = min(ix - b, 0) + W
|
92 |
+
my1 = min(iy - b, 0) + H
|
93 |
+
if mx0 < mx1 and my0 < my1:
|
94 |
+
m[:, :, my0:my1, mx0:mx1] = 1
|
95 |
+
return z, m
|
96 |
+
|
97 |
+
#----------------------------------------------------------------------------
|
98 |
+
# Construct an oriented low-pass filter that applies the appropriate
|
99 |
+
# bandlimit with respect to the input and output of the given affine 2D
|
100 |
+
# image transformation.
|
101 |
+
|
102 |
+
def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
|
103 |
+
assert a <= amax < aflt
|
104 |
+
mat = torch.as_tensor(mat).to(torch.float32)
|
105 |
+
|
106 |
+
# Construct 2D filter taps in input & output coordinate spaces.
|
107 |
+
taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
|
108 |
+
yi, xi = torch.meshgrid(taps, taps)
|
109 |
+
xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
|
110 |
+
|
111 |
+
# Convolution of two oriented 2D sinc filters.
|
112 |
+
fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
|
113 |
+
fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
|
114 |
+
f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
|
115 |
+
|
116 |
+
# Convolution of two oriented 2D Lanczos windows.
|
117 |
+
wi = lanczos_window(xi, a) * lanczos_window(yi, a)
|
118 |
+
wo = lanczos_window(xo, a) * lanczos_window(yo, a)
|
119 |
+
w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
|
120 |
+
|
121 |
+
# Construct windowed FIR filter.
|
122 |
+
f = f * w
|
123 |
+
|
124 |
+
# Finalize.
|
125 |
+
c = (aflt - amax) * up
|
126 |
+
f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
|
127 |
+
f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
|
128 |
+
f = f / f.sum([0,2], keepdim=True) / (up ** 2)
|
129 |
+
f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
|
130 |
+
return f
|
131 |
+
|
132 |
+
#----------------------------------------------------------------------------
|
133 |
+
# Apply the given affine transformation to a batch of 2D images.
|
134 |
+
|
135 |
+
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
|
136 |
+
_N, _C, H, W = x.shape
|
137 |
+
mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
|
138 |
+
|
139 |
+
# Construct filter.
|
140 |
+
f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
|
141 |
+
assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
|
142 |
+
p = f.shape[0] // 2
|
143 |
+
|
144 |
+
# Construct sampling grid.
|
145 |
+
theta = mat.inverse()
|
146 |
+
theta[:2, 2] *= 2
|
147 |
+
theta[0, 2] += 1 / up / W
|
148 |
+
theta[1, 2] += 1 / up / H
|
149 |
+
theta[0, :] *= W / (W + p / up * 2)
|
150 |
+
theta[1, :] *= H / (H + p / up * 2)
|
151 |
+
theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
|
152 |
+
g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
|
153 |
+
|
154 |
+
# Resample image.
|
155 |
+
y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
|
156 |
+
z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
|
157 |
+
|
158 |
+
# Form mask.
|
159 |
+
m = torch.zeros_like(y)
|
160 |
+
c = p * 2 + 1
|
161 |
+
m[:, :, c:-c, c:-c] = 1
|
162 |
+
m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
|
163 |
+
return z, m
|
164 |
+
|
165 |
+
#----------------------------------------------------------------------------
|
166 |
+
# Apply fractional rotation to a batch of 2D images. Corresponds to the
|
167 |
+
# operator R_\alpha in Appendix E.3.
|
168 |
+
|
169 |
+
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
|
170 |
+
angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
|
171 |
+
mat = rotation_matrix(angle)
|
172 |
+
return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
|
173 |
+
|
174 |
+
#----------------------------------------------------------------------------
|
175 |
+
# Modify the frequency content of a batch of 2D images as if they had undergo
|
176 |
+
# fractional rotation -- but without actually rotating them. Corresponds to
|
177 |
+
# the operator R^*_\alpha in Appendix E.3.
|
178 |
+
|
179 |
+
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
|
180 |
+
angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
|
181 |
+
mat = rotation_matrix(-angle)
|
182 |
+
f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
|
183 |
+
y = upfirdn2d.filter2d(x=x, f=f)
|
184 |
+
m = torch.zeros_like(y)
|
185 |
+
c = f.shape[0] // 2
|
186 |
+
m[:, :, c:-c, c:-c] = 1
|
187 |
+
return y, m
|
188 |
+
|
189 |
+
#----------------------------------------------------------------------------
|
190 |
+
# Compute the selected equivariance metrics for the given generator.
|
191 |
+
|
192 |
+
def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
|
193 |
+
assert compute_eqt_int or compute_eqt_frac or compute_eqr
|
194 |
+
|
195 |
+
# Setup generator and labels.
|
196 |
+
G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
|
197 |
+
I = torch.eye(3, device=opts.device)
|
198 |
+
M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
|
199 |
+
if M is None:
|
200 |
+
raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
|
201 |
+
c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
|
202 |
+
|
203 |
+
# Sampling loop.
|
204 |
+
sums = None
|
205 |
+
progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
|
206 |
+
for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
|
207 |
+
progress.update(batch_start)
|
208 |
+
s = []
|
209 |
+
|
210 |
+
# Randomize noise buffers, if any.
|
211 |
+
for name, buf in G.named_buffers():
|
212 |
+
if name.endswith('.noise_const'):
|
213 |
+
buf.copy_(torch.randn_like(buf))
|
214 |
+
|
215 |
+
# Run mapping network.
|
216 |
+
z = torch.randn([batch_size, G.z_dim], device=opts.device)
|
217 |
+
c = next(c_iter)
|
218 |
+
ws = G.mapping(z=z, c=c)
|
219 |
+
|
220 |
+
# Generate reference image.
|
221 |
+
M[:] = I
|
222 |
+
orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
|
223 |
+
|
224 |
+
# Integer translation (EQ-T).
|
225 |
+
if compute_eqt_int:
|
226 |
+
t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
|
227 |
+
t = (t * G.img_resolution).round() / G.img_resolution
|
228 |
+
M[:] = I
|
229 |
+
M[:2, 2] = -t
|
230 |
+
img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
|
231 |
+
ref, mask = apply_integer_translation(orig, t[0], t[1])
|
232 |
+
s += [(ref - img).square() * mask, mask]
|
233 |
+
|
234 |
+
# Fractional translation (EQ-T_frac).
|
235 |
+
if compute_eqt_frac:
|
236 |
+
t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
|
237 |
+
M[:] = I
|
238 |
+
M[:2, 2] = -t
|
239 |
+
img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
|
240 |
+
ref, mask = apply_fractional_translation(orig, t[0], t[1])
|
241 |
+
s += [(ref - img).square() * mask, mask]
|
242 |
+
|
243 |
+
# Rotation (EQ-R).
|
244 |
+
if compute_eqr:
|
245 |
+
angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
|
246 |
+
M[:] = rotation_matrix(-angle)
|
247 |
+
img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
|
248 |
+
ref, ref_mask = apply_fractional_rotation(orig, angle)
|
249 |
+
pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
|
250 |
+
mask = ref_mask * pseudo_mask
|
251 |
+
s += [(ref - pseudo).square() * mask, mask]
|
252 |
+
|
253 |
+
# Accumulate results.
|
254 |
+
s = torch.stack([x.to(torch.float64).sum() for x in s])
|
255 |
+
sums = sums + s if sums is not None else s
|
256 |
+
progress.update(num_samples)
|
257 |
+
|
258 |
+
# Compute PSNRs.
|
259 |
+
if opts.num_gpus > 1:
|
260 |
+
torch.distributed.all_reduce(sums)
|
261 |
+
sums = sums.cpu()
|
262 |
+
mses = sums[0::2] / sums[1::2]
|
263 |
+
psnrs = np.log10(2) * 20 - mses.log10() * 10
|
264 |
+
psnrs = tuple(psnrs.numpy())
|
265 |
+
return psnrs[0] if len(psnrs) == 1 else psnrs
|
266 |
+
|
267 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/frechet_inception_distance.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Frechet Inception Distance (FID) from the paper
|
10 |
+
"GANs trained by a two time-scale update rule converge to a local Nash
|
11 |
+
equilibrium". Matches the original implementation by Heusel et al. at
|
12 |
+
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import scipy.linalg
|
16 |
+
from . import metric_utils
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
19 |
+
|
20 |
+
def compute_fid(opts, max_real, num_gen):
|
21 |
+
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
22 |
+
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
|
23 |
+
detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
|
24 |
+
|
25 |
+
mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
|
26 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
27 |
+
rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
|
28 |
+
|
29 |
+
mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
|
30 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
31 |
+
rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
|
32 |
+
|
33 |
+
if opts.rank != 0:
|
34 |
+
return float('nan')
|
35 |
+
|
36 |
+
m = np.square(mu_gen - mu_real).sum()
|
37 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
38 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
39 |
+
return float(fid)
|
40 |
+
|
41 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/inception_score.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Inception Score (IS) from the paper "Improved techniques for training
|
10 |
+
GANs". Matches the original implementation by Salimans et al. at
|
11 |
+
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from . import metric_utils
|
15 |
+
|
16 |
+
#----------------------------------------------------------------------------
|
17 |
+
|
18 |
+
def compute_is(opts, num_gen, num_splits):
|
19 |
+
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
20 |
+
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
|
21 |
+
detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
|
22 |
+
|
23 |
+
gen_probs = metric_utils.compute_feature_stats_for_generator(
|
24 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
25 |
+
capture_all=True, max_items=num_gen).get_all()
|
26 |
+
|
27 |
+
if opts.rank != 0:
|
28 |
+
return float('nan'), float('nan')
|
29 |
+
|
30 |
+
scores = []
|
31 |
+
for i in range(num_splits):
|
32 |
+
part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
|
33 |
+
kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
|
34 |
+
kl = np.mean(np.sum(kl, axis=1))
|
35 |
+
scores.append(np.exp(kl))
|
36 |
+
return float(np.mean(scores)), float(np.std(scores))
|
37 |
+
|
38 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/kernel_inception_distance.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
|
10 |
+
GANs". Matches the original implementation by Binkowski et al. at
|
11 |
+
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from . import metric_utils
|
15 |
+
|
16 |
+
#----------------------------------------------------------------------------
|
17 |
+
|
18 |
+
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
|
19 |
+
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
20 |
+
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
|
21 |
+
detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
|
22 |
+
|
23 |
+
real_features = metric_utils.compute_feature_stats_for_dataset(
|
24 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
25 |
+
rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
|
26 |
+
|
27 |
+
gen_features = metric_utils.compute_feature_stats_for_generator(
|
28 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
29 |
+
rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
|
30 |
+
|
31 |
+
if opts.rank != 0:
|
32 |
+
return float('nan')
|
33 |
+
|
34 |
+
n = real_features.shape[1]
|
35 |
+
m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
|
36 |
+
t = 0
|
37 |
+
for _subset_idx in range(num_subsets):
|
38 |
+
x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
|
39 |
+
y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
|
40 |
+
a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
|
41 |
+
b = (x @ y.T / n + 1) ** 3
|
42 |
+
t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
|
43 |
+
kid = t / num_subsets / m
|
44 |
+
return float(kid)
|
45 |
+
|
46 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/metric_main.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Main API for computing and reporting quality metrics."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import json
|
14 |
+
import torch
|
15 |
+
import dnnlib
|
16 |
+
|
17 |
+
from . import metric_utils
|
18 |
+
from . import frechet_inception_distance
|
19 |
+
from . import kernel_inception_distance
|
20 |
+
from . import precision_recall
|
21 |
+
from . import perceptual_path_length
|
22 |
+
from . import inception_score
|
23 |
+
from . import equivariance
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
|
27 |
+
_metric_dict = dict() # name => fn
|
28 |
+
|
29 |
+
def register_metric(fn):
|
30 |
+
assert callable(fn)
|
31 |
+
_metric_dict[fn.__name__] = fn
|
32 |
+
return fn
|
33 |
+
|
34 |
+
def is_valid_metric(metric):
|
35 |
+
return metric in _metric_dict
|
36 |
+
|
37 |
+
def list_valid_metrics():
|
38 |
+
return list(_metric_dict.keys())
|
39 |
+
|
40 |
+
#----------------------------------------------------------------------------
|
41 |
+
|
42 |
+
def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
|
43 |
+
assert is_valid_metric(metric)
|
44 |
+
opts = metric_utils.MetricOptions(**kwargs)
|
45 |
+
|
46 |
+
# Calculate.
|
47 |
+
start_time = time.time()
|
48 |
+
results = _metric_dict[metric](opts)
|
49 |
+
total_time = time.time() - start_time
|
50 |
+
|
51 |
+
# Broadcast results.
|
52 |
+
for key, value in list(results.items()):
|
53 |
+
if opts.num_gpus > 1:
|
54 |
+
value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
|
55 |
+
torch.distributed.broadcast(tensor=value, src=0)
|
56 |
+
value = float(value.cpu())
|
57 |
+
results[key] = value
|
58 |
+
|
59 |
+
# Decorate with metadata.
|
60 |
+
return dnnlib.EasyDict(
|
61 |
+
results = dnnlib.EasyDict(results),
|
62 |
+
metric = metric,
|
63 |
+
total_time = total_time,
|
64 |
+
total_time_str = dnnlib.util.format_time(total_time),
|
65 |
+
num_gpus = opts.num_gpus,
|
66 |
+
)
|
67 |
+
|
68 |
+
#----------------------------------------------------------------------------
|
69 |
+
|
70 |
+
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
|
71 |
+
metric = result_dict['metric']
|
72 |
+
assert is_valid_metric(metric)
|
73 |
+
if run_dir is not None and snapshot_pkl is not None:
|
74 |
+
snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
|
75 |
+
|
76 |
+
jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
|
77 |
+
print(jsonl_line)
|
78 |
+
if run_dir is not None and os.path.isdir(run_dir):
|
79 |
+
with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
|
80 |
+
f.write(jsonl_line + '\n')
|
81 |
+
|
82 |
+
#----------------------------------------------------------------------------
|
83 |
+
# Recommended metrics.
|
84 |
+
|
85 |
+
@register_metric
|
86 |
+
def fid50k_full(opts):
|
87 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
88 |
+
fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
|
89 |
+
return dict(fid50k_full=fid)
|
90 |
+
|
91 |
+
@register_metric
|
92 |
+
def kid50k_full(opts):
|
93 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
94 |
+
kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
|
95 |
+
return dict(kid50k_full=kid)
|
96 |
+
|
97 |
+
@register_metric
|
98 |
+
def pr50k3_full(opts):
|
99 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
100 |
+
precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
|
101 |
+
return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
|
102 |
+
|
103 |
+
@register_metric
|
104 |
+
def ppl2_wend(opts):
|
105 |
+
ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
|
106 |
+
return dict(ppl2_wend=ppl)
|
107 |
+
|
108 |
+
@register_metric
|
109 |
+
def eqt50k_int(opts):
|
110 |
+
opts.G_kwargs.update(force_fp32=True)
|
111 |
+
psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
|
112 |
+
return dict(eqt50k_int=psnr)
|
113 |
+
|
114 |
+
@register_metric
|
115 |
+
def eqt50k_frac(opts):
|
116 |
+
opts.G_kwargs.update(force_fp32=True)
|
117 |
+
psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
|
118 |
+
return dict(eqt50k_frac=psnr)
|
119 |
+
|
120 |
+
@register_metric
|
121 |
+
def eqr50k(opts):
|
122 |
+
opts.G_kwargs.update(force_fp32=True)
|
123 |
+
psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
|
124 |
+
return dict(eqr50k=psnr)
|
125 |
+
|
126 |
+
#----------------------------------------------------------------------------
|
127 |
+
# Legacy metrics.
|
128 |
+
|
129 |
+
@register_metric
|
130 |
+
def fid50k(opts):
|
131 |
+
opts.dataset_kwargs.update(max_size=None)
|
132 |
+
fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
|
133 |
+
return dict(fid50k=fid)
|
134 |
+
|
135 |
+
@register_metric
|
136 |
+
def kid50k(opts):
|
137 |
+
opts.dataset_kwargs.update(max_size=None)
|
138 |
+
kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
|
139 |
+
return dict(kid50k=kid)
|
140 |
+
|
141 |
+
@register_metric
|
142 |
+
def pr50k3(opts):
|
143 |
+
opts.dataset_kwargs.update(max_size=None)
|
144 |
+
precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
|
145 |
+
return dict(pr50k3_precision=precision, pr50k3_recall=recall)
|
146 |
+
|
147 |
+
@register_metric
|
148 |
+
def is50k(opts):
|
149 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
150 |
+
mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
|
151 |
+
return dict(is50k_mean=mean, is50k_std=std)
|
152 |
+
|
153 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/metric_utils.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous utilities used internally by the quality metrics."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import hashlib
|
14 |
+
import pickle
|
15 |
+
import copy
|
16 |
+
import uuid
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import dnnlib
|
20 |
+
|
21 |
+
#----------------------------------------------------------------------------
|
22 |
+
|
23 |
+
class MetricOptions:
|
24 |
+
def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
|
25 |
+
assert 0 <= rank < num_gpus
|
26 |
+
self.G = G
|
27 |
+
self.G_kwargs = dnnlib.EasyDict(G_kwargs)
|
28 |
+
self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
|
29 |
+
self.num_gpus = num_gpus
|
30 |
+
self.rank = rank
|
31 |
+
self.device = device if device is not None else torch.device('cuda', rank)
|
32 |
+
self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
|
33 |
+
self.cache = cache
|
34 |
+
|
35 |
+
#----------------------------------------------------------------------------
|
36 |
+
|
37 |
+
_feature_detector_cache = dict()
|
38 |
+
|
39 |
+
def get_feature_detector_name(url):
|
40 |
+
return os.path.splitext(url.split('/')[-1])[0]
|
41 |
+
|
42 |
+
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
|
43 |
+
assert 0 <= rank < num_gpus
|
44 |
+
key = (url, device)
|
45 |
+
if key not in _feature_detector_cache:
|
46 |
+
is_leader = (rank == 0)
|
47 |
+
if not is_leader and num_gpus > 1:
|
48 |
+
torch.distributed.barrier() # leader goes first
|
49 |
+
with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
|
50 |
+
_feature_detector_cache[key] = pickle.load(f).to(device)
|
51 |
+
if is_leader and num_gpus > 1:
|
52 |
+
torch.distributed.barrier() # others follow
|
53 |
+
return _feature_detector_cache[key]
|
54 |
+
|
55 |
+
#----------------------------------------------------------------------------
|
56 |
+
|
57 |
+
def iterate_random_labels(opts, batch_size):
|
58 |
+
if opts.G.c_dim == 0:
|
59 |
+
c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
|
60 |
+
while True:
|
61 |
+
yield c
|
62 |
+
else:
|
63 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
64 |
+
while True:
|
65 |
+
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
|
66 |
+
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
67 |
+
yield c
|
68 |
+
|
69 |
+
#----------------------------------------------------------------------------
|
70 |
+
|
71 |
+
class FeatureStats:
|
72 |
+
def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
|
73 |
+
self.capture_all = capture_all
|
74 |
+
self.capture_mean_cov = capture_mean_cov
|
75 |
+
self.max_items = max_items
|
76 |
+
self.num_items = 0
|
77 |
+
self.num_features = None
|
78 |
+
self.all_features = None
|
79 |
+
self.raw_mean = None
|
80 |
+
self.raw_cov = None
|
81 |
+
|
82 |
+
def set_num_features(self, num_features):
|
83 |
+
if self.num_features is not None:
|
84 |
+
assert num_features == self.num_features
|
85 |
+
else:
|
86 |
+
self.num_features = num_features
|
87 |
+
self.all_features = []
|
88 |
+
self.raw_mean = np.zeros([num_features], dtype=np.float64)
|
89 |
+
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
|
90 |
+
|
91 |
+
def is_full(self):
|
92 |
+
return (self.max_items is not None) and (self.num_items >= self.max_items)
|
93 |
+
|
94 |
+
def append(self, x):
|
95 |
+
x = np.asarray(x, dtype=np.float32)
|
96 |
+
assert x.ndim == 2
|
97 |
+
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
|
98 |
+
if self.num_items >= self.max_items:
|
99 |
+
return
|
100 |
+
x = x[:self.max_items - self.num_items]
|
101 |
+
|
102 |
+
self.set_num_features(x.shape[1])
|
103 |
+
self.num_items += x.shape[0]
|
104 |
+
if self.capture_all:
|
105 |
+
self.all_features.append(x)
|
106 |
+
if self.capture_mean_cov:
|
107 |
+
x64 = x.astype(np.float64)
|
108 |
+
self.raw_mean += x64.sum(axis=0)
|
109 |
+
self.raw_cov += x64.T @ x64
|
110 |
+
|
111 |
+
def append_torch(self, x, num_gpus=1, rank=0):
|
112 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 2
|
113 |
+
assert 0 <= rank < num_gpus
|
114 |
+
if num_gpus > 1:
|
115 |
+
ys = []
|
116 |
+
for src in range(num_gpus):
|
117 |
+
y = x.clone()
|
118 |
+
torch.distributed.broadcast(y, src=src)
|
119 |
+
ys.append(y)
|
120 |
+
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
|
121 |
+
self.append(x.cpu().numpy())
|
122 |
+
|
123 |
+
def get_all(self):
|
124 |
+
assert self.capture_all
|
125 |
+
return np.concatenate(self.all_features, axis=0)
|
126 |
+
|
127 |
+
def get_all_torch(self):
|
128 |
+
return torch.from_numpy(self.get_all())
|
129 |
+
|
130 |
+
def get_mean_cov(self):
|
131 |
+
assert self.capture_mean_cov
|
132 |
+
mean = self.raw_mean / self.num_items
|
133 |
+
cov = self.raw_cov / self.num_items
|
134 |
+
cov = cov - np.outer(mean, mean)
|
135 |
+
return mean, cov
|
136 |
+
|
137 |
+
def save(self, pkl_file):
|
138 |
+
with open(pkl_file, 'wb') as f:
|
139 |
+
pickle.dump(self.__dict__, f)
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def load(pkl_file):
|
143 |
+
with open(pkl_file, 'rb') as f:
|
144 |
+
s = dnnlib.EasyDict(pickle.load(f))
|
145 |
+
obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
|
146 |
+
obj.__dict__.update(s)
|
147 |
+
return obj
|
148 |
+
|
149 |
+
#----------------------------------------------------------------------------
|
150 |
+
|
151 |
+
class ProgressMonitor:
|
152 |
+
def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
|
153 |
+
self.tag = tag
|
154 |
+
self.num_items = num_items
|
155 |
+
self.verbose = verbose
|
156 |
+
self.flush_interval = flush_interval
|
157 |
+
self.progress_fn = progress_fn
|
158 |
+
self.pfn_lo = pfn_lo
|
159 |
+
self.pfn_hi = pfn_hi
|
160 |
+
self.pfn_total = pfn_total
|
161 |
+
self.start_time = time.time()
|
162 |
+
self.batch_time = self.start_time
|
163 |
+
self.batch_items = 0
|
164 |
+
if self.progress_fn is not None:
|
165 |
+
self.progress_fn(self.pfn_lo, self.pfn_total)
|
166 |
+
|
167 |
+
def update(self, cur_items):
|
168 |
+
assert (self.num_items is None) or (cur_items <= self.num_items)
|
169 |
+
if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
|
170 |
+
return
|
171 |
+
cur_time = time.time()
|
172 |
+
total_time = cur_time - self.start_time
|
173 |
+
time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
|
174 |
+
if (self.verbose) and (self.tag is not None):
|
175 |
+
print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
|
176 |
+
self.batch_time = cur_time
|
177 |
+
self.batch_items = cur_items
|
178 |
+
|
179 |
+
if (self.progress_fn is not None) and (self.num_items is not None):
|
180 |
+
self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
|
181 |
+
|
182 |
+
def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
|
183 |
+
return ProgressMonitor(
|
184 |
+
tag = tag,
|
185 |
+
num_items = num_items,
|
186 |
+
flush_interval = flush_interval,
|
187 |
+
verbose = self.verbose,
|
188 |
+
progress_fn = self.progress_fn,
|
189 |
+
pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
|
190 |
+
pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
|
191 |
+
pfn_total = self.pfn_total,
|
192 |
+
)
|
193 |
+
|
194 |
+
#----------------------------------------------------------------------------
|
195 |
+
|
196 |
+
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
|
197 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
198 |
+
if data_loader_kwargs is None:
|
199 |
+
data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
|
200 |
+
|
201 |
+
# Try to lookup from cache.
|
202 |
+
cache_file = None
|
203 |
+
if opts.cache:
|
204 |
+
# Choose cache file name.
|
205 |
+
args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
|
206 |
+
md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
|
207 |
+
cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
|
208 |
+
cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
|
209 |
+
|
210 |
+
# Check if the file exists (all processes must agree).
|
211 |
+
flag = os.path.isfile(cache_file) if opts.rank == 0 else False
|
212 |
+
if opts.num_gpus > 1:
|
213 |
+
flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
|
214 |
+
torch.distributed.broadcast(tensor=flag, src=0)
|
215 |
+
flag = (float(flag.cpu()) != 0)
|
216 |
+
|
217 |
+
# Load.
|
218 |
+
if flag:
|
219 |
+
return FeatureStats.load(cache_file)
|
220 |
+
|
221 |
+
# Initialize.
|
222 |
+
num_items = len(dataset)
|
223 |
+
if max_items is not None:
|
224 |
+
num_items = min(num_items, max_items)
|
225 |
+
stats = FeatureStats(max_items=num_items, **stats_kwargs)
|
226 |
+
progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
|
227 |
+
detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
|
228 |
+
|
229 |
+
# Main loop.
|
230 |
+
item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
|
231 |
+
for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
|
232 |
+
if images.shape[1] == 1:
|
233 |
+
images = images.repeat([1, 3, 1, 1])
|
234 |
+
features = detector(images.to(opts.device), **detector_kwargs)
|
235 |
+
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
236 |
+
progress.update(stats.num_items)
|
237 |
+
|
238 |
+
# Save to cache.
|
239 |
+
if cache_file is not None and opts.rank == 0:
|
240 |
+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
241 |
+
temp_file = cache_file + '.' + uuid.uuid4().hex
|
242 |
+
stats.save(temp_file)
|
243 |
+
os.replace(temp_file, cache_file) # atomic
|
244 |
+
return stats
|
245 |
+
|
246 |
+
#----------------------------------------------------------------------------
|
247 |
+
|
248 |
+
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
|
249 |
+
if batch_gen is None:
|
250 |
+
batch_gen = min(batch_size, 4)
|
251 |
+
assert batch_size % batch_gen == 0
|
252 |
+
|
253 |
+
# Setup generator and labels.
|
254 |
+
G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
|
255 |
+
c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
|
256 |
+
|
257 |
+
# Initialize.
|
258 |
+
stats = FeatureStats(**stats_kwargs)
|
259 |
+
assert stats.max_items is not None
|
260 |
+
progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
|
261 |
+
detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
|
262 |
+
|
263 |
+
# Main loop.
|
264 |
+
while not stats.is_full():
|
265 |
+
images = []
|
266 |
+
for _i in range(batch_size // batch_gen):
|
267 |
+
z = torch.randn([batch_gen, G.z_dim], device=opts.device)
|
268 |
+
img = G(z=z, c=next(c_iter), **opts.G_kwargs)
|
269 |
+
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
270 |
+
images.append(img)
|
271 |
+
images = torch.cat(images)
|
272 |
+
if images.shape[1] == 1:
|
273 |
+
images = images.repeat([1, 3, 1, 1])
|
274 |
+
features = detector(images, **detector_kwargs)
|
275 |
+
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
276 |
+
progress.update(stats.num_items)
|
277 |
+
return stats
|
278 |
+
|
279 |
+
#----------------------------------------------------------------------------
|
stylegan3-fun/metrics/perceptual_path_length.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
|
10 |
+
Architecture for Generative Adversarial Networks". Matches the original
|
11 |
+
implementation by Karras et al. at
|
12 |
+
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
|
13 |
+
|
14 |
+
import copy
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from . import metric_utils
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
|
21 |
+
# Spherical interpolation of a batch of vectors.
|
22 |
+
def slerp(a, b, t):
|
23 |
+
a = a / a.norm(dim=-1, keepdim=True)
|
24 |
+
b = b / b.norm(dim=-1, keepdim=True)
|
25 |
+
d = (a * b).sum(dim=-1, keepdim=True)
|
26 |
+
p = t * torch.acos(d)
|
27 |
+
c = b - d * a
|
28 |
+
c = c / c.norm(dim=-1, keepdim=True)
|
29 |
+
d = a * torch.cos(p) + c * torch.sin(p)
|
30 |
+
d = d / d.norm(dim=-1, keepdim=True)
|
31 |
+
return d
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
class PPLSampler(torch.nn.Module):
|
36 |
+
def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
|
37 |
+
assert space in ['z', 'w']
|
38 |
+
assert sampling in ['full', 'end']
|
39 |
+
super().__init__()
|
40 |
+
self.G = copy.deepcopy(G)
|
41 |
+
self.G_kwargs = G_kwargs
|
42 |
+
self.epsilon = epsilon
|
43 |
+
self.space = space
|
44 |
+
self.sampling = sampling
|
45 |
+
self.crop = crop
|
46 |
+
self.vgg16 = copy.deepcopy(vgg16)
|
47 |
+
|
48 |
+
def forward(self, c):
|
49 |
+
# Generate random latents and interpolation t-values.
|
50 |
+
t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
|
51 |
+
z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
|
52 |
+
|
53 |
+
# Interpolate in W or Z.
|
54 |
+
if self.space == 'w':
|
55 |
+
w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
|
56 |
+
wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
|
57 |
+
wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
|
58 |
+
else: # space == 'z'
|
59 |
+
zt0 = slerp(z0, z1, t.unsqueeze(1))
|
60 |
+
zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
|
61 |
+
wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
|
62 |
+
|
63 |
+
# Randomize noise buffers.
|
64 |
+
for name, buf in self.G.named_buffers():
|
65 |
+
if name.endswith('.noise_const'):
|
66 |
+
buf.copy_(torch.randn_like(buf))
|
67 |
+
|
68 |
+
# Generate images.
|
69 |
+
img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
|
70 |
+
|
71 |
+
# Center crop.
|
72 |
+
if self.crop:
|
73 |
+
assert img.shape[2] == img.shape[3]
|
74 |
+
c = img.shape[2] // 8
|
75 |
+
img = img[:, :, c*3 : c*7, c*2 : c*6]
|
76 |
+
|
77 |
+
# Downsample to 256x256.
|
78 |
+
factor = self.G.img_resolution // 256
|
79 |
+
if factor > 1:
|
80 |
+
img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
|
81 |
+
|
82 |
+
# Scale dynamic range from [-1,1] to [0,255].
|
83 |
+
img = (img + 1) * (255 / 2)
|
84 |
+
if self.G.img_channels == 1:
|
85 |
+
img = img.repeat([1, 3, 1, 1])
|
86 |
+
|
87 |
+
# Evaluate differential LPIPS.
|
88 |
+
lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
|
89 |
+
dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
|
90 |
+
return dist
|
91 |
+
|
92 |
+
#----------------------------------------------------------------------------
|
93 |
+
|
94 |
+
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
|
95 |
+
vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
|
96 |
+
vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
|
97 |
+
|
98 |
+
# Setup sampler and labels.
|
99 |
+
sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
|
100 |
+
sampler.eval().requires_grad_(False).to(opts.device)
|
101 |
+
c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
|
102 |
+
|
103 |
+
# Sampling loop.
|
104 |
+
dist = []
|
105 |
+
progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
|
106 |
+
for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
|
107 |
+
progress.update(batch_start)
|
108 |
+
x = sampler(next(c_iter))
|
109 |
+
for src in range(opts.num_gpus):
|
110 |
+
y = x.clone()
|
111 |
+
if opts.num_gpus > 1:
|
112 |
+
torch.distributed.broadcast(y, src=src)
|
113 |
+
dist.append(y)
|
114 |
+
progress.update(num_samples)
|
115 |
+
|
116 |
+
# Compute PPL.
|
117 |
+
if opts.rank != 0:
|
118 |
+
return float('nan')
|
119 |
+
dist = torch.cat(dist)[:num_samples].cpu().numpy()
|
120 |
+
lo = np.percentile(dist, 1, interpolation='lower')
|
121 |
+
hi = np.percentile(dist, 99, interpolation='higher')
|
122 |
+
ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
|
123 |
+
return float(ppl)
|
124 |
+
|
125 |
+
#----------------------------------------------------------------------------
|