2Much2Code:) committed
Commit ba5dcdc
1 parent: b5e22af
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +201 -0
  2. README.md +7 -6
  3. app.py +288 -0
  4. components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz +3 -0
  5. config.py +72 -0
  6. decomposition.py +402 -0
  7. environment.yml +25 -0
  8. estimators.py +218 -0
  9. models/__init__.py +11 -0
  10. models/biggan/__init__.py +8 -0
  11. models/biggan/pytorch_biggan/.gitignore +110 -0
  12. models/biggan/pytorch_biggan/LICENSE +21 -0
  13. models/biggan/pytorch_biggan/MANIFEST.in +1 -0
  14. models/biggan/pytorch_biggan/README.md +227 -0
  15. models/biggan/pytorch_biggan/full_requirements.txt +5 -0
  16. models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py +6 -0
  17. models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py +70 -0
  18. models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py +312 -0
  19. models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py +249 -0
  20. models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py +345 -0
  21. models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py +216 -0
  22. models/biggan/pytorch_biggan/requirements.txt +8 -0
  23. models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh +21 -0
  24. models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh +21 -0
  25. models/biggan/pytorch_biggan/setup.py +69 -0
  26. models/stylegan/__init__.py +17 -0
  27. models/stylegan/model.py +456 -0
  28. models/stylegan/stylegan_tf/LICENSE.txt +410 -0
  29. models/stylegan/stylegan_tf/README.md +232 -0
  30. models/stylegan/stylegan_tf/config.py +18 -0
  31. models/stylegan/stylegan_tf/dataset_tool.py +645 -0
  32. models/stylegan/stylegan_tf/dnnlib/__init__.py +20 -0
  33. models/stylegan/stylegan_tf/dnnlib/submission/__init__.py +9 -0
  34. models/stylegan/stylegan_tf/dnnlib/submission/_internal/run.py +45 -0
  35. models/stylegan/stylegan_tf/dnnlib/submission/run_context.py +99 -0
  36. models/stylegan/stylegan_tf/dnnlib/submission/submit.py +290 -0
  37. models/stylegan/stylegan_tf/dnnlib/tflib/__init__.py +16 -0
  38. models/stylegan/stylegan_tf/dnnlib/tflib/autosummary.py +184 -0
  39. models/stylegan/stylegan_tf/dnnlib/tflib/network.py +591 -0
  40. models/stylegan/stylegan_tf/dnnlib/tflib/optimizer.py +214 -0
  41. models/stylegan/stylegan_tf/dnnlib/tflib/tfutil.py +240 -0
  42. models/stylegan/stylegan_tf/dnnlib/util.py +405 -0
  43. models/stylegan/stylegan_tf/generate_figures.py +161 -0
  44. models/stylegan/stylegan_tf/metrics/__init__.py +8 -0
  45. models/stylegan/stylegan_tf/metrics/frechet_inception_distance.py +72 -0
  46. models/stylegan/stylegan_tf/metrics/linear_separability.py +177 -0
  47. models/stylegan/stylegan_tf/metrics/metric_base.py +142 -0
  48. models/stylegan/stylegan_tf/metrics/perceptual_path_length.py +108 -0
  49. models/stylegan/stylegan_tf/pretrained_example.py +47 -0
  50. models/stylegan/stylegan_tf/run_metrics.py +105 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020 Erik Härkönen
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Guccio AI Designer
- emoji: 📉
- colorFrom: green
- colorTo: purple
+ title: ClothingGAN
+ emoji: 👘
+ colorFrom: indigo
+ colorTo: gray
  sdk: gradio
- sdk_version: 3.1.4
+ sdk_version: 2.9.4
  app_file: app.py
  pinned: false
+ license: cc-by-nc-3.0
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,288 @@
1
+ from ssl import ALERT_DESCRIPTION_CLOSE_NOTIFY
2
+ import nltk; nltk.download('wordnet')
3
+
4
+ #@title Load Model
5
+ selected_model = 'lookbook'
6
+
7
+ # Load model
8
+ from IPython.utils import io
9
+ import torch
10
+ import PIL
11
+ import numpy as np
12
+ import ipywidgets as widgets
13
+ from PIL import Image
14
+ import imageio
15
+ from models import get_instrumented_model
16
+ from decomposition import get_or_compute
17
+ from config import Config
18
+ from skimage import img_as_ubyte
19
+ import gradio as gr
20
+ import numpy as np
21
+ from ipywidgets import fixed
22
+
23
+ # Speed up computation
24
+ torch.autograd.set_grad_enabled(False)
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ # Specify model to use
28
+ config = Config(
29
+ model='StyleGAN2',
30
+ layer='style',
31
+ output_class=selected_model,
32
+ components=80,
33
+ use_w=True,
34
+ batch_size=5_000, # style layer quite small
35
+ )
36
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
+
38
+ inst = get_instrumented_model(config.model, config.output_class,
39
+ config.layer, torch.device(device), use_w=config.use_w)
40
+
41
+ path_to_components = get_or_compute(config, inst)
42
+
43
+ model = inst.model
44
+
45
+ comps = np.load(path_to_components)
46
+ lst = comps.files
47
+ latent_dirs = []
48
+ latent_stdevs = []
49
+
50
+ load_activations = False
51
+
52
+ for item in lst:
53
+ if load_activations:
54
+ if item == 'act_comp':
55
+ for i in range(comps[item].shape[0]):
56
+ latent_dirs.append(comps[item][i])
57
+ if item == 'act_stdev':
58
+ for i in range(comps[item].shape[0]):
59
+ latent_stdevs.append(comps[item][i])
60
+ else:
61
+ if item == 'lat_comp':
62
+ for i in range(comps[item].shape[0]):
63
+ latent_dirs.append(comps[item][i])
64
+ if item == 'lat_stdev':
65
+ for i in range(comps[item].shape[0]):
66
+ latent_stdevs.append(comps[item][i])
67
+
68
+
69
+ #@title Define functions
70
+
71
+
72
+ # Taken from https://github.com/alexanderkuk/log-progress
73
+ def log_progress(sequence, every=1, size=None, name='Items'):
74
+ from ipywidgets import IntProgress, HTML, VBox
75
+ from IPython.display import display
76
+
77
+ is_iterator = False
78
+ if size is None:
79
+ try:
80
+ size = len(sequence)
81
+ except TypeError:
82
+ is_iterator = True
83
+ if size is not None:
84
+ if every is None:
85
+ if size <= 200:
86
+ every = 1
87
+ else:
88
+ every = int(size / 200) # every 0.5%
89
+ else:
90
+ assert every is not None, 'sequence is iterator, set every'
91
+
92
+ if is_iterator:
93
+ progress = IntProgress(min=0, max=1, value=1)
94
+ progress.bar_style = 'info'
95
+ else:
96
+ progress = IntProgress(min=0, max=size, value=0)
97
+ label = HTML()
98
+ box = VBox(children=[label, progress])
99
+ display(box)
100
+
101
+ index = 0
102
+ try:
103
+ for index, record in enumerate(sequence, 1):
104
+ if index == 1 or index % every == 0:
105
+ if is_iterator:
106
+ label.value = '{name}: {index} / ?'.format(
107
+ name=name,
108
+ index=index
109
+ )
110
+ else:
111
+ progress.value = index
112
+ label.value = u'{name}: {index} / {size}'.format(
113
+ name=name,
114
+ index=index,
115
+ size=size
116
+ )
117
+ yield record
118
+ except:
119
+ progress.bar_style = 'danger'
120
+ raise
121
+ else:
122
+ progress.bar_style = 'success'
123
+ progress.value = index
124
+ label.value = "{name}: {index}".format(
125
+ name=name,
126
+ index=str(index or '?')
127
+ )
128
+
129
+ def name_direction(sender):
130
+ if not text.value:
131
+ print('Please name the direction before saving')
132
+ return
133
+
134
+ if num in named_directions.values():
135
+ target_key = list(named_directions.keys())[list(named_directions.values()).index(num)]
136
+ print(f'Direction already named: {target_key}')
137
+ print(f'Overwriting... ')
138
+ del(named_directions[target_key])
139
+ named_directions[text.value] = [num, start_layer.value, end_layer.value]
140
+ save_direction(random_dir, text.value)
141
+ for item in named_directions:
142
+ print(item, named_directions[item])
143
+
144
+ def save_direction(direction, filename):
145
+ filename += ".npy"
146
+ np.save(filename, direction, allow_pickle=True, fix_imports=True)
147
+ print(f'Latent direction saved as {filename}')
148
+
149
+ def mix_w(w1, w2, content, style):
150
+ for i in range(0,5):
151
+ w2[i] = w1[i] * (1 - content) + w2[i] * content
152
+
153
+ for i in range(5, 16):
154
+ w2[i] = w1[i] * (1 - style) + w2[i] * style
155
+
156
+ return w2
157
+
158
+ def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):
159
+ # blockPrint()
160
+ model.truncation = truncation
161
+ if w is None:
162
+ w = model.sample_latent(1, seed=seed).detach().cpu().numpy()
163
+ w = [w]*model.get_max_latents() # one per layer
164
+ else:
165
+ w = [np.expand_dims(x, 0) for x in w]
166
+
167
+ for l in range(start, end):
168
+ for i in range(len(directions)):
169
+ w[l] = w[l] + directions[i] * distances[i] * scale
170
+
171
+ torch.cuda.empty_cache()
172
+ #save image and display
173
+ out = model.sample_np(w)
174
+ final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS)
175
+
176
+
177
+ if save is not None:
178
+ if disp == False:
179
+ print(save)
180
+ final_im.save(f'out/{seed}_{save:05}.png')
181
+ if disp:
182
+ display(final_im)
183
+
184
+ return final_im
185
+
186
+ def generate_mov(seed, truncation, direction_vec, scale, layers, n_frames, out_name = 'out', noise_spec = None, loop=True):
187
+ """Generates a mov moving back and forth along the chosen direction vector"""
188
+ # Example of reading a generated set of images, and storing as MP4.
189
+ movieName = f'{out_name}.mp4'
190
+ offset = -10
191
+ step = 20 / n_frames
192
+ imgs = []
193
+ for i in log_progress(range(n_frames), name = "Generating frames"):
194
+ print(f'\r{i} / {n_frames}', end='')
195
+ w = model.sample_latent(1, seed=seed).cpu().numpy()
196
+
197
+ model.truncation = truncation
198
+ w = [w]*model.get_max_latents() # one per layer
199
+ for l in layers:
200
+ if l <= model.get_max_latents():
201
+ w[l] = w[l] + direction_vec * offset * scale
202
+
203
+ #save image and display
204
+ out = model.sample_np(w)
205
+ final_im = Image.fromarray((out * 255).astype(np.uint8))
206
+ imgs.append(out)
207
+ #increase offset
208
+ offset += step
209
+ if loop:
210
+ imgs += imgs[::-1]
211
+ with imageio.get_writer(movieName, mode='I') as writer:
212
+ for image in log_progress(list(imgs), name = "Creating animation"):
213
+ writer.append_data(img_as_ubyte(image))
214
+
215
+
216
+ #@title Demo UI
217
+
218
+
219
+ def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer):
220
+ seed1 = int(seed1)
221
+ seed2 = int(seed2)
222
+
223
+ scale = 1
224
+ params = {'c0': c0,
225
+ 'c1': c1,
226
+ 'c2': c2,
227
+ 'c3': c3,
228
+ 'c4': c4,
229
+ 'c5': c5,
230
+ 'c6': c6}
231
+
232
+ param_indexes = {'c0': 0,
233
+ 'c1': 1,
234
+ 'c2': 2,
235
+ 'c3': 3,
236
+ 'c4': 4,
237
+ 'c5': 5,
238
+ 'c6': 6}
239
+
240
+ directions = []
241
+ distances = []
242
+ for k, v in params.items():
243
+ directions.append(latent_dirs[param_indexes[k]])
244
+ distances.append(v)
245
+
246
+ w1 = model.sample_latent(1, seed=seed1).detach().cpu().numpy()
247
+ w1 = [w1]*model.get_max_latents() # one per layer
248
+ im1 = model.sample_np(w1)
249
+
250
+ w2 = model.sample_latent(1, seed=seed2).detach().cpu().numpy()
251
+ w2 = [w2]*model.get_max_latents() # one per layer
252
+ im2 = model.sample_np(w2)
253
+ combined_im = np.concatenate([im1, im2], axis=1)
254
+ input_im = Image.fromarray((combined_im * 255).astype(np.uint8))
255
+
256
+
257
+ mixed_w = mix_w(w1, w2, content, style)
258
+ return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False)
259
+
260
+ truncation = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label="Truncation")
261
+ start_layer = gr.inputs.Number(default=3, label="Start Layer")
262
+ end_layer = gr.inputs.Number(default=14, label="End Layer")
263
+ seed1 = gr.inputs.Number(default=0, label="Seed 1")
264
+ seed2 = gr.inputs.Number(default=0, label="Seed 2")
265
+ content = gr.inputs.Slider(label="Structure", minimum=0, maximum=1, default=0.5)
266
+ style = gr.inputs.Slider(label="Style", minimum=0, maximum=1, default=0.5)
267
+
268
+ slider_max_val = 20
269
+ slider_min_val = -20
270
+ slider_step = 1
271
+
272
+ c0 = gr.inputs.Slider(label="Sleeve & Size", minimum=slider_min_val, maximum=slider_max_val, default=0)
273
+ c1 = gr.inputs.Slider(label="Dress - Jacket", minimum=slider_min_val, maximum=slider_max_val, default=0)
274
+ c2 = gr.inputs.Slider(label="Female Coat", minimum=slider_min_val, maximum=slider_max_val, default=0)
275
+ c3 = gr.inputs.Slider(label="Coat", minimum=slider_min_val, maximum=slider_max_val, default=0)
276
+ c4 = gr.inputs.Slider(label="Graphics", minimum=slider_min_val, maximum=slider_max_val, default=0)
277
+ c5 = gr.inputs.Slider(label="Dark", minimum=slider_min_val, maximum=slider_max_val, default=0)
278
+ c6 = gr.inputs.Slider(label="Less Cleavage", minimum=slider_min_val, maximum=slider_max_val, default=0)
279
+
280
+
281
+ scale = 1
282
+
283
+ inputs = [seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer]
284
+ description = "Change the seed number to generate a different parent design. Please give a clap/star if you find it useful :)"
285
+
286
+ article="<p style='text-align: center'>Made by <a href='https://github.com/AdiNarendra98/AI-for-Fashion/tree/main/Guccio%20AI(Clothing%20Design%20Generator)' target='_blank'>@AdiNarendra with 🖤 </a>. Thanks to @mfrashad for the inspiration for this.</p> "
287
+
288
+ gr.Interface(generate_image, inputs, ["image", "image"], description=description, live=True,article=article,title="ClothingGAN").launch()
components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc8dd611f9eba549338aaac546bc11dfda01ced79456ee0bb63387adf997bde1
3
+ size 312337
config.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2020 Erik Härkönen. All rights reserved.
2
+ # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License. You may obtain a copy
4
+ # of the License at http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software distributed under
7
+ # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8
+ # OF ANY KIND, either express or implied. See the License for the specific language
9
+ # governing permissions and limitations under the License.
10
+
11
+ import sys
12
+ import argparse
13
+ import json
14
+ from copy import deepcopy
15
+
16
+ class Config:
17
+ def __init__(self, **kwargs):
18
+ self.from_args([]) # set all defaults
19
+ self.default_args = deepcopy(self.__dict__)
20
+ self.from_dict(kwargs) # override
21
+
22
+ def __str__(self):
23
+ custom = {}
24
+ default = {}
25
+
26
+ # Find non-default arguments
27
+ for k, v in self.__dict__.items():
28
+ if k == 'default_args':
29
+ continue
30
+
31
+ in_default = k in self.default_args
32
+ same_value = self.default_args.get(k) == v
33
+
34
+ if in_default and same_value:
35
+ default[k] = v
36
+ else:
37
+ custom[k] = v
38
+
39
+ config = {
40
+ 'custom': custom,
41
+ 'default': default
42
+ }
43
+
44
+ return json.dumps(config, indent=4)
45
+
46
+ def __repr__(self):
47
+ return self.__str__()
48
+
49
+ def from_dict(self, dictionary):
50
+ for k, v in dictionary.items():
51
+ setattr(self, k, v)
52
+ return self
53
+
54
+ def from_args(self, args=sys.argv[1:]):
55
+ parser = argparse.ArgumentParser(description='GAN component analysis config')
56
+ parser.add_argument('--model', dest='model', type=str, default='StyleGAN', help='The network to analyze') # StyleGAN, DCGAN, ProGAN, BigGAN-XYZ
57
+ parser.add_argument('--layer', dest='layer', type=str, default='g_mapping', help='The layer to analyze')
58
+ parser.add_argument('--class', dest='output_class', type=str, default=None, help='Output class to generate (BigGAN: Imagenet, ProGAN: LSUN)')
59
+ parser.add_argument('--est', dest='estimator', type=str, default='ipca', help='The algorithm to use [pca, fbpca, cupca, spca, ica]')
60
+ parser.add_argument('--sparsity', type=float, default=1.0, help='Sparsity parameter of SPCA')
61
+ parser.add_argument('--video', dest='make_video', action='store_true', help='Generate output videos (MP4s)')
62
+ parser.add_argument('--batch', dest='batch_mode', action='store_true', help="Don't open windows, instead save results to file")
63
+ parser.add_argument('-b', dest='batch_size', type=int, default=None, help='Minibatch size, leave empty for automatic detection')
64
+ parser.add_argument('-c', dest='components', type=int, default=80, help='Number of components to keep')
65
+ parser.add_argument('-n', type=int, default=300_000, help='Number of examples to use in decomposition')
66
+ parser.add_argument('--use_w', action='store_true', help='Use W latent space (StyleGAN(2))')
67
+ parser.add_argument('--sigma', type=float, default=2.0, help='Number of stdevs to walk in visualize.py')
68
+ parser.add_argument('--inputs', type=str, default=None, help='Path to directory with named components')
69
+ parser.add_argument('--seed', type=int, default=None, help='Seed used in decomposition')
70
+ args = parser.parse_args(args)
71
+
72
+ return self.from_dict(args.__dict__)
decomposition.py ADDED
@@ -0,0 +1,402 @@
1
+ # Copyright 2020 Erik Härkönen. All rights reserved.
2
+ # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License. You may obtain a copy
4
+ # of the License at http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software distributed under
7
+ # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8
+ # OF ANY KIND, either express or implied. See the License for the specific language
9
+ # governing permissions and limitations under the License.
10
+
11
+ # Patch for broken CTRL+C handler
12
+ # https://github.com/ContinuumIO/anaconda-issues/issues/905
13
+ import os
14
+ os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'
15
+
16
+ import numpy as np
17
+ import os
18
+ from pathlib import Path
19
+ import re
20
+ import sys
21
+ import datetime
22
+ import argparse
23
+ import torch
24
+ import json
25
+ from types import SimpleNamespace
26
+ import scipy
27
+ from scipy.cluster.vq import kmeans
28
+ from tqdm import trange
29
+ from netdissect.nethook import InstrumentedModel
30
+ from config import Config
31
+ from estimators import get_estimator
32
+ from models import get_instrumented_model
33
+
34
+ SEED_SAMPLING = 1
35
+ SEED_RANDOM_DIRS = 2
36
+ SEED_LINREG = 3
37
+ SEED_VISUALIZATION = 5
38
+
39
+ B = 20
40
+ n_clusters = 500
41
+
42
+ def get_random_dirs(components, dimensions):
43
+ gen = np.random.RandomState(seed=SEED_RANDOM_DIRS)
44
+ dirs = gen.normal(size=(components, dimensions))
45
+ dirs /= np.sqrt(np.sum(dirs**2, axis=1, keepdims=True))
46
+ return dirs.astype(np.float32)
47
+
48
+ # Compute maximum batch size for given VRAM and network
49
+ def get_max_batch_size(inst, device, layer_name=None):
50
+ inst.remove_edits()
51
+
52
+ # Reset statistics
53
+ torch.cuda.reset_max_memory_cached(device)
54
+ torch.cuda.reset_max_memory_allocated(device)
55
+ total_mem = torch.cuda.get_device_properties(device).total_memory
56
+
57
+ B_max = 20
58
+
59
+ # Measure actual usage
60
+ for i in range(2, B_max, 2):
61
+ z = inst.model.sample_latent(n_samples=i)
62
+ if layer_name:
63
+ inst.model.partial_forward(z, layer_name)
64
+ else:
65
+ inst.model.forward(z)
66
+
67
+ maxmem = torch.cuda.max_memory_allocated(device)
68
+ del z
69
+
70
+ if maxmem > 0.5*total_mem:
71
+ print('Batch size {:d}: memory usage {:.0f}MB'.format(i, maxmem / 1e6))
72
+ return i
73
+
74
+ return B_max
75
+
76
+ # Solve for directions in latent space that match PCs in activation space
77
+ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
78
+ print('Performing least squares regression', flush=True)
79
+
80
+ torch.manual_seed(SEED_LINREG)
81
+ np.random.seed(SEED_LINREG)
82
+
83
+ comp = torch.from_numpy(comp_np).float().to(inst.model.device)
84
+ mean = torch.from_numpy(mean_np).float().to(inst.model.device)
85
+ stdev = torch.from_numpy(stdev_np).float().to(inst.model.device)
86
+
87
+ n_samp = max(10_000, config.n) // B * B # make divisible
88
+ n_comp = comp.shape[0]
89
+ latent_dims = inst.model.get_latent_dims()
90
+
91
+ # We're looking for M s.t. M*P*G'(Z) = Z => M*A = Z
92
+ # Z = batch of latent vectors (n_samples x latent_dims)
93
+ # G'(Z) = batch of activations at intermediate layer
94
+ # A = P*G'(Z) = projected activations (n_samples x pca_coords)
95
+ # M = linear mapping (pca_coords x latent_dims)
96
+
97
+ # Minimization min_M ||MA - Z||_l2 rewritten as min_M.T ||A.T*M.T - Z.T||_l2
98
+ # to match format expected by pytorch.lstsq
99
+
100
+ # TODO: regression on pixel-space outputs? (using nonlinear optimizer)
101
+ # min_M lpips(G_full(MA), G_full(Z))
102
+
103
+ # Tensors to fill with data
104
+ # Dimensions other way around, so these are actually the transposes
105
+ A = np.zeros((n_samp, n_comp), dtype=np.float32)
106
+ Z = np.zeros((n_samp, latent_dims), dtype=np.float32)
107
+
108
+ # Project tensor X onto PCs, return coordinates
109
+ def project(X, comp):
110
+ N = X.shape[0]
111
+ K = comp.shape[0]
112
+ coords = torch.bmm(comp.expand([N]+[-1]*comp.ndim), X.view(N, -1, 1))
113
+ return coords.reshape(N, K)
114
+
115
+ for i in trange(n_samp // B, desc='Collecting samples', ascii=True):
116
+ z = inst.model.sample_latent(B)
117
+ inst.model.partial_forward(z, config.layer)
118
+ act = inst.retained_features()[config.layer].reshape(B, -1)
119
+
120
+ # Project onto basis
121
+ act = act - mean
122
+ coords = project(act, comp)
123
+ coords_scaled = coords / stdev
124
+
125
+ A[i*B:(i+1)*B] = coords_scaled.detach().cpu().numpy()
126
+ Z[i*B:(i+1)*B] = z.detach().cpu().numpy().reshape(B, -1)
127
+
128
+ # Solve least squares fit
129
+
130
+ # gelsd = divide-and-conquer SVD; good default
131
+ # gelsy = complete orthogonal factorization; sometimes faster
132
+ # gelss = SVD; slow but less memory hungry
133
+ M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :]
134
+
135
+ # Solution given by rows of M_t
136
+ Z_comp = M_t[:n_comp, :]
137
+ Z_mean = np.mean(Z, axis=0, keepdims=True)
138
+
139
+ return Z_comp, Z_mean
140
+
141
+ def regression(comp, mean, stdev, inst, config):
142
+ # Sanity check: verify orthonormality
143
+ M = np.dot(comp, comp.T)
144
+ if not np.allclose(M, np.identity(M.shape[0])):
145
+ det = np.linalg.det(M)
146
+ print(f'WARNING: Computed basis is not orthonormal (determinant={det})')
147
+
148
+ return linreg_lstsq(comp, mean, stdev, inst, config)
149
+
150
+ def compute(config, dump_name, instrumented_model):
151
+ global B
152
+
153
+ timestamp = lambda : datetime.datetime.now().strftime("%d.%m %H:%M")
154
+ print(f'[{timestamp()}] Computing', dump_name.name)
155
+
156
+ # Ensure reproducibility
157
+ torch.manual_seed(0) # also sets cuda seeds
158
+ np.random.seed(0)
159
+
160
+ # Speed up backend
161
+ torch.backends.cudnn.benchmark = True
162
+
163
+ has_gpu = torch.cuda.is_available()
164
+ device = torch.device('cuda' if has_gpu else 'cpu')
165
+ layer_key = config.layer
166
+
167
+ if instrumented_model is None:
168
+ inst = get_instrumented_model(config.model, config.output_class, layer_key, device)
169
+ model = inst.model
170
+ else:
171
+ print('Reusing InstrumentedModel instance')
172
+ inst = instrumented_model
173
+ model = inst.model
174
+ inst.remove_edits()
175
+ model.set_output_class(config.output_class)
176
+
177
+ # Regress back to w space
178
+ if config.use_w:
179
+ print('Using W latent space')
180
+ model.use_w()
181
+
182
+ inst.retain_layer(layer_key)
183
+ model.partial_forward(model.sample_latent(1), layer_key)
184
+ sample_shape = inst.retained_features()[layer_key].shape
185
+ sample_dims = np.prod(sample_shape)
186
+ print('Feature shape:', sample_shape)
187
+
188
+ input_shape = inst.model.get_latent_shape()
189
+ input_dims = inst.model.get_latent_dims()
190
+
191
+ config.components = min(config.components, sample_dims)
192
+ transformer = get_estimator(config.estimator, config.components, config.sparsity)
193
+
194
+ X = None
195
+ X_global_mean = None
196
+
197
+ # Figure out batch size if not provided
198
+ B = config.batch_size or get_max_batch_size(inst, device, layer_key)
199
+
200
+ # Divisible by B (ignored in output name)
201
+ N = config.n // B * B
202
+
203
+ # Compute maximum batch size based on RAM + pagefile budget
204
+ target_bytes = 20 * 1_000_000_000 # GB
205
+ feat_size_bytes = sample_dims * np.dtype('float64').itemsize
206
+ N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes)
207
+ if not transformer.batch_support and N > N_limit_RAM:
208
+ print('WARNING: estimator does not support batching, ' \
209
+ 'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N))
210
+
211
+ # 32-bit LAPACK gets very unhappy about huge matrices (in linalg.svd)
212
+ if config.estimator == 'ica':
213
+ lapack_max_N = np.floor_divide(np.iinfo(np.int32).max // 4, sample_dims) # 4x extra buffer
214
+ if N > lapack_max_N:
215
+ raise RuntimeError(f'Matrices too large for ICA, please use N <= {lapack_max_N}')
216
+
217
+ print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True)
218
+
219
+ # Must not depend on chosen batch size (reproducibility)
220
+ NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible!
221
+
222
+ samples = None
223
+ if not transformer.batch_support:
224
+ samples = np.zeros((N + NB, sample_dims), dtype=np.float32)
225
+
226
+ torch.manual_seed(config.seed or SEED_SAMPLING)
227
+ np.random.seed(config.seed or SEED_SAMPLING)
228
+
229
+ # Use exactly the same latents regardless of batch size
230
+ # Store in main memory, since N might be huge (1M+)
231
+ # Run in batches, since sample_latent() might perform Z -> W mapping
232
+ n_lat = ((N + NB - 1) // B + 1) * B
233
+ latents = np.zeros((n_lat, *input_shape[1:]), dtype=np.float32)
234
+ with torch.no_grad():
235
+ for i in trange(n_lat // B, desc='Sampling latents'):
236
+ latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy()
237
+
238
+ # Decomposition on non-Gaussian latent space
239
+ samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W'
240
+
241
+ canceled = False
242
+ try:
243
+ X = np.ones((NB, sample_dims), dtype=np.float32)
244
+ action = 'Fitting' if transformer.batch_support else 'Collecting'
245
+ for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
246
+ for mb in range(0, NB, B):
247
+ z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device)
248
+
249
+ if samples_are_latents:
250
+ # Decomposition on latents directly (e.g. StyleGAN W)
251
+ batch = z.reshape((B, -1))
252
+ else:
253
+ # Decomposition on intermediate layer
254
+ with torch.no_grad():
255
+ model.partial_forward(z, layer_key)
256
+
257
+ # Permuted to place PCA dimensions last
258
+ batch = inst.retained_features()[layer_key].reshape((B, -1))
259
+
260
+ space_left = min(B, NB - mb)
261
+ X[mb:mb+space_left] = batch.cpu().numpy()[:space_left]
262
+
263
+ if transformer.batch_support:
264
+ if not transformer.fit_partial(X.reshape(-1, sample_dims)):
265
+ break
266
+ else:
267
+ samples[gi:gi+NB, :] = X.copy()
268
+ except KeyboardInterrupt:
269
+ if not transformer.batch_support:
270
+ sys.exit(1) # no progress yet
271
+
272
+ dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}')
273
+ print(f'Saving current state to "{dump_name.name}" before exiting')
274
+ canceled = True
275
+
276
+ if not transformer.batch_support:
277
+ X = samples # Use all samples
278
+ X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...!
279
+ X -= X_global_mean
280
+
281
+ print(f'[{timestamp()}] Fitting whole batch')
282
+ t_start_fit = datetime.datetime.now()
283
+
284
+ transformer.fit(X)
285
+
286
+ print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
287
+ assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero'
288
+ else:
289
+ X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims))
290
+ X = X.reshape(-1, sample_dims)
291
+ X -= X_global_mean
292
+
293
+ X_comp, X_stdev, X_var_ratio = transformer.get_components()
294
+
295
+ assert X_comp.shape[1] == sample_dims \
296
+ and X_comp.shape[0] == config.components \
297
+ and X_global_mean.shape[1] == sample_dims \
298
+ and X_stdev.shape[0] == config.components, 'Invalid shape'
299
+
300
+ # 'Activations' are really latents in a secondary latent space
301
+ if samples_are_latents:
302
+ Z_comp = X_comp
303
+ Z_global_mean = X_global_mean
304
+ else:
305
+ Z_comp, Z_global_mean = regression(X_comp, X_global_mean, X_stdev, inst, config)
306
+
307
+ # Normalize
308
+ Z_comp /= np.linalg.norm(Z_comp, axis=-1, keepdims=True)
309
+
310
+ # Random projections
311
+ # We expect these to explain much less of the variance
312
+ random_dirs = get_random_dirs(config.components, np.prod(sample_shape))
313
+ n_rand_samples = min(5000, X.shape[0])
314
+ X_view = X[:n_rand_samples, :].T
315
+ assert np.shares_memory(X_view, X), "Error: slice produced copy"
316
+ X_stdev_random = np.dot(random_dirs, X_view).std(axis=1)
317
+
318
+ # Inflate back to proper shapes (for easier broadcasting)
319
+ X_comp = X_comp.reshape(-1, *sample_shape)
320
+ X_global_mean = X_global_mean.reshape(sample_shape)
321
+ Z_comp = Z_comp.reshape(-1, *input_shape)
322
+ Z_global_mean = Z_global_mean.reshape(input_shape)
323
+
324
+ # Compute stdev in latent space if non-Gaussian
325
+ lat_stdev = np.ones_like(X_stdev)
326
+ if config.use_w:
327
+ samples = model.sample_latent(5000).reshape(5000, input_dims).detach().cpu().numpy()
328
+ coords = np.dot(Z_comp.reshape(-1, input_dims), samples.T)
329
+ lat_stdev = coords.std(axis=1)
330
+
331
+ os.makedirs(dump_name.parent, exist_ok=True)
332
+ np.savez_compressed(dump_name, **{
333
+ 'act_comp': X_comp.astype(np.float32),
334
+ 'act_mean': X_global_mean.astype(np.float32),
335
+ 'act_stdev': X_stdev.astype(np.float32),
336
+ 'lat_comp': Z_comp.astype(np.float32),
337
+ 'lat_mean': Z_global_mean.astype(np.float32),
338
+ 'lat_stdev': lat_stdev.astype(np.float32),
339
+ 'var_ratio': X_var_ratio.astype(np.float32),
340
+ 'random_stdevs': X_stdev_random.astype(np.float32),
341
+ })
342
+
343
+ if canceled:
344
+ sys.exit(1)
345
+
346
+ # Don't shutdown if passed as param
347
+ if instrumented_model is None:
348
+ inst.close()
349
+ del inst
350
+ del model
351
+
352
+ del X
353
+ del X_comp
354
+ del random_dirs
355
+ del batch
356
+ del samples
357
+ del latents
358
+ torch.cuda.empty_cache()
359
+
360
+ # Return cached results or compute if needed
361
+ # Pass existing InstrumentedModel instance to reuse it
362
+ def get_or_compute(config, model=None, submit_config=None, force_recompute=False):
363
+ if submit_config is None:
364
+ wrkdir = str(Path(__file__).parent.resolve())
365
+ submit_config = SimpleNamespace(run_dir_root = wrkdir, run_dir = wrkdir)
366
+
367
+ # Called directly by run.py
368
+ return _compute(submit_config, config, model, force_recompute)
369
+
370
+ def _compute(submit_config, config, model=None, force_recompute=False):
371
+ basedir = Path(submit_config.run_dir)
372
+ outdir = basedir / 'out'
373
+
374
+ if config.n is None:
375
+ raise RuntimeError('Must specify number of samples with -n=XXX')
376
+
377
+ if model and not isinstance(model, InstrumentedModel):
378
+ raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"')
379
+
380
+ if config.use_w and not 'StyleGAN' in config.model:
381
+ raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}')
382
+
383
+ transformer = get_estimator(config.estimator, config.components, config.sparsity)
384
+ dump_name = "{}-{}_{}_{}_n{}{}{}.npz".format(
385
+ config.model.lower(),
386
+ config.output_class.replace(' ', '_'),
387
+ config.layer.lower(),
388
+ transformer.get_param_str(),
389
+ config.n,
390
+ '_w' if config.use_w else '',
391
+ f'_seed{config.seed}' if config.seed else ''
392
+ )
393
+
394
+ dump_path = basedir / 'cache' / 'components' / dump_name
395
+
396
+ if not dump_path.is_file() or force_recompute:
397
+ print('Not cached')
398
+ t_start = datetime.datetime.now()
399
+ compute(config, dump_path, model)
400
+ print('Total time:', datetime.datetime.now() - t_start)
401
+
402
+ return dump_path
environment.yml ADDED
@@ -0,0 +1,25 @@
1
+ name: ganspace
2
+ channels:
3
+ - defaults
4
+ - conda-forge
5
+ - pytorch
6
+ dependencies:
7
+ - python=3.7
8
+ - pytorch::pytorch=1.3
9
+ - pytorch::torchvision
10
+ - cudatoolkit=10.1
11
+ - pillow=6.2
12
+ - ffmpeg
13
+ - tqdm
14
+ - scipy
15
+ - scikit-learn
16
+ - scikit-image
17
+ - boto3
18
+ - requests
19
+ - nltk
20
+ - pip
21
+ - pip:
22
+ - fbpca
23
+ - pyopengltk
24
+
25
+ # conda env update -f environment.yml --prune
estimators.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright 2020 Erik Härkönen. All rights reserved.
2
+ # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License. You may obtain a copy
4
+ # of the License at http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software distributed under
7
+ # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8
+ # OF ANY KIND, either express or implied. See the License for the specific language
9
+ # governing permissions and limitations under the License.
10
+
11
+ from sklearn.decomposition import FastICA, PCA, IncrementalPCA, MiniBatchSparsePCA, SparsePCA, KernelPCA
12
+ import fbpca
13
+ import numpy as np
14
+ import itertools
15
+ from types import SimpleNamespace
16
+
17
+ # ICA
18
+ class ICAEstimator():
19
+ def __init__(self, n_components):
20
+ self.n_components = n_components
21
+ self.maxiter = 10000
22
+ self.whiten = True # ICA: whitening is essential, should not be skipped
23
+ self.transformer = FastICA(n_components, random_state=0, whiten=self.whiten, max_iter=self.maxiter)
24
+ self.batch_support = False
25
+ self.stdev = np.zeros((n_components,))
26
+ self.total_var = 0.0
27
+
28
+ def get_param_str(self):
29
+ return "ica_c{}{}".format(self.n_components, '_w' if self.whiten else '')
30
+
31
+ def fit(self, X):
32
+ self.transformer.fit(X)
33
+ if self.transformer.n_iter_ >= self.maxiter:
34
+ raise RuntimeError(f'FastICA did not converge (N={X.shape[0]}, it={self.maxiter})')
35
+
36
+ # Normalize components
37
+ self.transformer.components_ /= np.sqrt(np.sum(self.transformer.components_**2, axis=-1, keepdims=True))
38
+
39
+ # Save variance for later
40
+ self.total_var = X.var(axis=0).sum()
41
+
42
+ # Compute projected standard deviations
43
+ self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1)
44
+
45
+ # Sort components based on explained variance
46
+ idx = np.argsort(self.stdev)[::-1]
47
+ self.stdev = self.stdev[idx]
48
+ self.transformer.components_[:] = self.transformer.components_[idx]
49
+
50
+ def get_components(self):
51
+ var_ratio = self.stdev**2 / self.total_var
52
+ return self.transformer.components_, self.stdev, var_ratio # ICA outputs are not normalized
53
+
54
+ # Incremental PCA
55
+ class IPCAEstimator():
56
+ def __init__(self, n_components):
57
+ self.n_components = n_components
58
+ self.whiten = False
59
+ self.transformer = IncrementalPCA(n_components, whiten=self.whiten, batch_size=max(100, 2*n_components))
60
+ self.batch_support = True
61
+
62
+ def get_param_str(self):
63
+ return "ipca_c{}{}".format(self.n_components, '_w' if self.whiten else '')
64
+
65
+ def fit(self, X):
66
+ self.transformer.fit(X)
67
+
68
+ def fit_partial(self, X):
69
+ try:
70
+ self.transformer.partial_fit(X)
71
+ self.transformer.n_samples_seen_ = \
72
+ self.transformer.n_samples_seen_.astype(np.int64) # avoid overflow
73
+ return True
74
+ except ValueError as e:
75
+ print(f'\nIPCA error:', e)
76
+ return False
77
+
78
+ def get_components(self):
79
+ stdev = np.sqrt(self.transformer.explained_variance_) # already sorted
80
+ var_ratio = self.transformer.explained_variance_ratio_
81
+ return self.transformer.components_, stdev, var_ratio # PCA outputs are normalized
82
+
83
+ # Standard PCA
84
+ class PCAEstimator():
85
+ def __init__(self, n_components):
86
+ self.n_components = n_components
87
+ self.solver = 'full'
88
+ self.transformer = PCA(n_components, svd_solver=self.solver)
89
+ self.batch_support = False
90
+
91
+ def get_param_str(self):
92
+ return f"pca-{self.solver}_c{self.n_components}"
93
+
94
+ def fit(self, X):
95
+ self.transformer.fit(X)
96
+
97
+ # Save variance for later
98
+ self.total_var = X.var(axis=0).sum()
99
+
100
+ # Compute projected standard deviations
101
+ self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1)
102
+
103
+ # Sort components based on explained variance
104
+ idx = np.argsort(self.stdev)[::-1]
105
+ self.stdev = self.stdev[idx]
106
+ self.transformer.components_[:] = self.transformer.components_[idx]
107
+
108
+ # Check orthogonality
109
+ dotps = [np.dot(*self.transformer.components_[[i, j]])
110
+ for (i, j) in itertools.combinations(range(self.n_components), 2)]
111
+ if not np.allclose(dotps, 0, atol=1e-4):
112
+ print('IPCA components not orthogonal, max dot', np.abs(dotps).max())
113
+
114
+ self.transformer.mean_ = X.mean(axis=0, keepdims=True)
115
+
116
+ def get_components(self):
117
+ var_ratio = self.stdev**2 / self.total_var
118
+ return self.transformer.components_, self.stdev, var_ratio
119
+
120
+ # Facebook's PCA
121
+ # Good default choice: very fast and accurate.
122
+ # Very high sample counts won't fit into RAM,
123
+ # in which case IncrementalPCA must be used.
124
+ class FacebookPCAEstimator():
125
+ def __init__(self, n_components):
126
+ self.n_components = n_components
127
+ self.transformer = SimpleNamespace()
128
+ self.batch_support = False
129
+ self.n_iter = 2
130
+ self.l = 2*self.n_components
131
+
132
+ def get_param_str(self):
133
+ return "fbpca_c{}_it{}_l{}".format(self.n_components, self.n_iter, self.l)
134
+
135
+ def fit(self, X):
136
+ U, s, Va = fbpca.pca(X, k=self.n_components, n_iter=self.n_iter, raw=True, l=self.l)
137
+ self.transformer.components_ = Va
138
+
139
+ # Save variance for later
140
+ self.total_var = X.var(axis=0).sum()
141
+
142
+ # Compute projected standard deviations
143
+ self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1)
144
+
145
+ # Sort components based on explained variance
146
+ idx = np.argsort(self.stdev)[::-1]
147
+ self.stdev = self.stdev[idx]
148
+ self.transformer.components_[:] = self.transformer.components_[idx]
149
+
150
+ # Check orthogonality
151
+ dotps = [np.dot(*self.transformer.components_[[i, j]])
152
+ for (i, j) in itertools.combinations(range(self.n_components), 2)]
153
+ if not np.allclose(dotps, 0, atol=1e-4):
154
+ print('FBPCA components not orthogonal, max dot', np.abs(dotps).max())
155
+
156
+ self.transformer.mean_ = X.mean(axis=0, keepdims=True)
157
+
158
+ def get_components(self):
159
+ var_ratio = self.stdev**2 / self.total_var
160
+ return self.transformer.components_, self.stdev, var_ratio
161
+
162
+ # Sparse PCA
163
+ # The algorithm is online along the features direction, not the samples direction
164
+ # => no partial_fit
165
+ class SPCAEstimator():
166
+ def __init__(self, n_components, alpha=10.0):
167
+ self.n_components = n_components
168
+ self.whiten = False
169
+ self.alpha = alpha # higher alpha => sparser components
170
+ #self.transformer = MiniBatchSparsePCA(n_components, alpha=alpha, n_iter=100,
171
+ # batch_size=max(20, n_components//5), random_state=0, normalize_components=True)
172
+ self.transformer = SparsePCA(n_components, alpha=alpha, ridge_alpha=0.01,
173
+ max_iter=100, random_state=0, n_jobs=-1, normalize_components=True) # TODO: warm start using PCA result?
174
+ self.batch_support = False # maybe through memmap and HDD-stored tensor
175
+ self.stdev = np.zeros((n_components,))
176
+ self.total_var = 0.0
177
+
178
+ def get_param_str(self):
179
+ return "spca_c{}_a{}{}".format(self.n_components, self.alpha, '_w' if self.whiten else '')
180
+
181
+ def fit(self, X):
182
+ self.transformer.fit(X)
183
+
184
+ # Save variance for later
185
+ self.total_var = X.var(axis=0).sum()
186
+
187
+ # Compute projected standard deviations
188
+ # NB: cannot simply project with dot product!
189
+ self.stdev = self.transformer.transform(X).std(axis=0) # X = (n_samples, n_features)
190
+
191
+ # Sort components based on explained variance
192
+ idx = np.argsort(self.stdev)[::-1]
193
+ self.stdev = self.stdev[idx]
194
+ self.transformer.components_[:] = self.transformer.components_[idx]
195
+
196
+ # Check orthogonality
197
+ dotps = [np.dot(*self.transformer.components_[[i, j]])
198
+ for (i, j) in itertools.combinations(range(self.n_components), 2)]
199
+ if not np.allclose(dotps, 0, atol=1e-4):
200
+ print('SPCA components not orthogonal, max dot', np.abs(dotps).max())
201
+
202
+ def get_components(self):
203
+ var_ratio = self.stdev**2 / self.total_var
204
+ return self.transformer.components_, self.stdev, var_ratio # SPCA outputs are normalized
205
+
206
+ def get_estimator(name, n_components, alpha):
207
+ if name == 'pca':
208
+ return PCAEstimator(n_components)
209
+ if name == 'ipca':
210
+ return IPCAEstimator(n_components)
211
+ elif name == 'fbpca':
212
+ return FacebookPCAEstimator(n_components)
213
+ elif name == 'ica':
214
+ return ICAEstimator(n_components)
215
+ elif name == 'spca':
216
+ return SPCAEstimator(n_components, alpha)
217
+ else:
218
+ raise RuntimeError('Unknown estimator')
models/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright 2020 Erik Härkönen. All rights reserved.
2
+ # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License. You may obtain a copy
4
+ # of the License at http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software distributed under
7
+ # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8
+ # OF ANY KIND, either express or implied. See the License for the specific language
9
+ # governing permissions and limitations under the License.
10
+
11
+ from .wrappers import *
models/biggan/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from pathlib import Path
2
+ import sys
3
+
4
+ module_path = Path(__file__).parent / 'pytorch_biggan'
5
+ sys.path.append(str(module_path.resolve()))
6
+ from pytorch_pretrained_biggan import *
7
+ from pytorch_pretrained_biggan.model import GenBlock
8
+ from pytorch_pretrained_biggan.file_utils import http_get, s3_get
models/biggan/pytorch_biggan/.gitignore ADDED
@@ -0,0 +1,110 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ # vscode
107
+ .vscode/
108
+
109
+ # models
110
+ models/
models/biggan/pytorch_biggan/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Erik Härkönen
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/biggan/pytorch_biggan/MANIFEST.in ADDED
@@ -0,0 +1 @@
1
+ include LICENSE
models/biggan/pytorch_biggan/README.md ADDED
@@ -0,0 +1,227 @@
1
+ # BigStyleGAN
2
+ This is a copy of HuggingFace's BigGAN implementation, with the addition of layerwise latent inputs.
3
+
4
+ # PyTorch pretrained BigGAN
5
+ An op-for-op PyTorch reimplementation of DeepMind's BigGAN model with the pre-trained weights from DeepMind.
6
+
7
+ ## Introduction
8
+
9
+ This repository contains an op-for-op PyTorch reimplementation of DeepMind's BigGAN that was released with the paper [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://openreview.net/forum?id=B1xsqj09Fm) by Andrew Brock, Jeff Donahue and Karen Simonyan.
10
+
11
+ This PyTorch implementation of BigGAN is provided with the [pretrained 128x128, 256x256 and 512x512 models by DeepMind](https://tfhub.dev/deepmind/biggan-deep-128/1). We also provide the scripts used to download and convert these models from the TensorFlow Hub models.
12
+
13
+ This reimplementation was done from the raw computation graph of the TensorFlow version and behaves similarly to the TensorFlow version (the variance of the output difference is on the order of 1e-5).
14
+
15
+ This implementation currently only contains the generator, as the weights of the discriminator were not released (the structure of the discriminator is very similar to the generator, so it could be added fairly easily; tell me if you want to do a PR on that, I would be happy to help).
16
+
17
+ ## Installation
18
+
19
+ This repo was tested on Python 3.6 and PyTorch 1.0.1
20
+
21
+ PyTorch pretrained BigGAN can be installed from pip as follows:
22
+ ```bash
23
+ pip install pytorch-pretrained-biggan
24
+ ```
25
+
26
+ If you simply want to play with the GAN this should be enough.
27
+
28
+ If you want to use the conversion scripts and the imagenet utilities, additional requirements are needed, in particular TensorFlow and NLTK. To install all the requirements please use the `full_requirements.txt` file:
29
+ ```bash
30
+ git clone https://github.com/huggingface/pytorch-pretrained-BigGAN.git
31
+ cd pytorch-pretrained-BigGAN
32
+ pip install -r full_requirements.txt
33
+ ```
34
+
35
+ ## Models
36
+
37
+ This repository provides direct and simple access to the pretrained "deep" versions of BigGAN for 128, 256 and 512 pixel resolutions as described in the [associated publication](https://openreview.net/forum?id=B1xsqj09Fm).
38
+ Here are some details on the models:
39
+
40
+ - `BigGAN-deep-128`: a 50.4M parameter model generating 128x128 pixel images; the model dump weighs 201 MB,
41
+ - `BigGAN-deep-256`: a 55.9M parameter model generating 256x256 pixel images; the model dump weighs 224 MB,
42
+ - `BigGAN-deep-512`: a 56.2M parameter model generating 512x512 pixel images; the model dump weighs 225 MB.
43
+
44
+ Please refer to Appendix B of the paper for details on the architectures.
45
+
46
+ All models comprise pre-computed batch norm statistics for 51 truncation values between 0 and 1 (see Appendix C.1 in the paper for details).
47
+
48
+ ## Usage
49
+
50
+ Here is a quick-start example using `BigGAN` with a pre-trained model.
51
+
52
+ See the [doc section](#doc) below for details on these classes and methods.
53
+
54
+ ```python
55
+ import torch
56
+ from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
57
+ save_as_images, display_in_terminal)
58
+
59
+ # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
60
+ import logging
61
+ logging.basicConfig(level=logging.INFO)
62
+
63
+ # Load the pre-trained model
64
+ model = BigGAN.from_pretrained('biggan-deep-256')
65
+
66
+ # Prepare an input
67
+ truncation = 0.4
68
+ class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
69
+ noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3)
70
+
71
+ # All in tensors
72
+ noise_vector = torch.from_numpy(noise_vector)
73
+ class_vector = torch.from_numpy(class_vector)
74
+
75
+ # If you have a GPU, put everything on cuda
76
+ noise_vector = noise_vector.to('cuda')
77
+ class_vector = class_vector.to('cuda')
78
+ model.to('cuda')
79
+
80
+ # Generate an image
81
+ with torch.no_grad():
82
+ output = model(noise_vector, class_vector, truncation)
83
+
84
+ # If you have a GPU put back on CPU
85
+ output = output.to('cpu')
86
+
87
+ # If you have a sixel-compatible terminal you can display the images in the terminal
88
+ # (see https://github.com/saitoha/libsixel for details)
89
+ display_in_terminal(output)
90
+
91
+ # Save results as png images
92
+ save_as_images(output)
93
+ ```
94
+
95
+ ![output_0](assets/output_0.png)
96
+ ![output_1](assets/output_1.png)
97
+ ![output_2](assets/output_2.png)
98
+
99
+ ## Doc
100
+
101
+ ### Loading DeepMind's pre-trained weights
102
+
103
+ To load one of DeepMind's pre-trained models, instantiate a `BigGAN` model with `from_pretrained()` as:
104
+
105
+ ```python
106
+ model = BigGAN.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
107
+ ```
108
+
109
+ where
110
+
111
+ - `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
112
+
113
+ - the shortcut name of a DeepMind pre-trained model selected in the list:
114
+
115
+ - `biggan-deep-128`: 50.4M parameters, generates 128x128 pixel images
116
+ - `biggan-deep-256`: 55.9M parameters, generates 256x256 pixel images
117
+ - `biggan-deep-512`: 56.2M parameters, generates 512x512 pixel images
118
+
119
+ - a path or url to a pretrained model archive containing:
120
+
121
+ - `config.json`: a configuration file for the model, and
122
+ - `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BigGAN` (saved with the usual `torch.save()`).
123
+
124
+ If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_biggan/model.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_biggan/`).
125
+ - `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights.
126
+
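+ For example, to load from a local folder (a minimal sketch; the folder path and cache directory below are just placeholders):
+
+ ```python
+ from pytorch_pretrained_biggan import BigGAN
+
+ # './my-biggan-256' must contain config.json and pytorch_model.bin
+ model = BigGAN.from_pretrained('./my-biggan-256', cache_dir='/tmp/biggan_cache')
+ ```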
127
+ ### Configuration
128
+
129
+ `BigGANConfig` is a class to store and load BigGAN configurations. It's defined in [`config.py`](./pytorch_pretrained_biggan/config.py).
130
+
131
+ Here are some details on the attributes (a short construction example follows the list):
132
+
133
+ - `output_dim`: output resolution of the GAN (128, 256 or 512) for the pre-trained models,
134
+ - `z_dim`: size of the noise vector (128 for the pre-trained models).
135
+ - `class_embed_dim`: size of the class embedding vectors (128 for the pre-trained models).
136
+ - `channel_width`: size of each channel (128 for the pre-trained models).
137
+ - `num_classes`: number of classes in the training dataset, like imagenet (1000 for the pre-trained models).
138
+ - `layers`: a list of layer definitions. Each layer is defined by a triple of [up-sample in this layer? (bool), number of input channels (int), number of output channels (int)]
139
+ - `attention_layer_position`: Position of the self-attention layer in the layer hierarchy (8 for the pre-trained models).
140
+ - `eps`: epsilon value to use for spectral and batch normalization layers (1e-4 for the pre-trained models).
141
+ - `n_stats`: number of pre-computed statistics for the batch normalization layers associated to various truncation values between 0 and 1 (51 for the pre-trained models).
142
+
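+ As a minimal sketch (the values below simply restate the 128x128 defaults):
+
+ ```python
+ from pytorch_pretrained_biggan import BigGANConfig
+
+ # Build a config equivalent to the 128x128 default and round-trip it through a dict
+ config = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128)
+ same_config = BigGANConfig.from_dict(config.to_dict())
+ print(same_config.to_json_string())
+ ```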
143
+ ### Model
144
+
145
+ `BigGAN` is a PyTorch model (`torch.nn.Module`) of BigGAN defined in [`model.py`](./pytorch_pretrained_biggan/model.py). This model comprises the class embeddings (a linear layer) and the generator with a series of convolutions and conditional batch norms. The discriminator is currently not implemented since pre-trained weights have not been released for it.
146
+
147
+ The inputs and output are **identical to the TensorFlow model inputs and outputs**.
148
+
149
+ We detail them here.
150
+
151
+ `BigGAN` takes as *inputs*:
152
+
153
+ - `z`: a torch.FloatTensor of shape [batch_size, config.z_dim] with noise sampled from a truncated normal distribution, and
154
+ - `class_label`: a torch.FloatTensor of shape [batch_size, num_classes] (1000 classes for the pre-trained models) containing one-hot encoded class vectors, and
155
+ - `truncation`: a float between 0 (exclusive) and 1: the truncation of the truncated normal used for creating the noise vector. This truncation value is also used to select among the sets of pre-computed statistics (means and variances) for the batch norm layers.
156
+
157
+ `BigGAN` *outputs* an array of shape [batch_size, 3, resolution, resolution] where resolution is 128, 256 or 512 depending on the model.
158
+
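+ Since this fork adds layerwise latent inputs, `z` may also be a list with one latent per generator input (`model.n_latents` entries); `class_label` can likewise be a list. A minimal sketch of this (assuming the `biggan-deep-256` model from the quick-start above):
+
+ ```python
+ import torch
+ from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
+
+ model = BigGAN.from_pretrained('biggan-deep-256')
+ truncation = 0.4
+ label = torch.from_numpy(one_hot_from_names(['coffee'], batch_size=1))
+
+ # One independent noise vector per latent input
+ latents = [torch.from_numpy(truncated_noise_sample(batch_size=1, truncation=truncation))
+            for _ in range(model.n_latents)]
+
+ with torch.no_grad():
+     output = model(latents, label, truncation)  # shape [1, 3, 256, 256]
+ ```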
159
+ ### Utilities: Images, Noise, Imagenet classes
160
+
161
+ We provide a few utility methods to use the model. They are defined in [`utils.py`](./pytorch_pretrained_biggan/utils.py). A combined example is given at the end of this section.
162
+
163
+ Here are some details on these methods:
164
+
165
+ - `truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None)`:
166
+
167
+ Create a truncated noise vector.
168
+ - Params:
169
+ - batch_size: batch size.
170
+ - dim_z: dimension of z
171
+ - truncation: truncation value to use
172
+ - seed: seed for the random generator
173
+ - Output:
174
+ array of shape (batch_size, dim_z)
175
+
176
+ - `convert_to_images(obj)`:
177
+
178
+ Convert an output tensor from BigGAN in a list of images.
179
+ - Params:
180
+ - obj: tensor or numpy array of shape (batch_size, channels, height, width)
181
+ - Output:
182
+ - list of Pillow Images of size (height, width)
183
+
184
+ - `save_as_images(obj, file_name='output')`:
185
+
186
+ Convert and save an output tensor from BigGAN in a list of saved images.
187
+ - Params:
188
+ - obj: tensor or numpy array of shape (batch_size, channels, height, width)
189
+ - file_name: path and beginning of the filename to save.
190
+ Images will be saved as `file_name_{image_number}.png`
191
+
192
+ - `display_in_terminal(obj)`:
193
+
194
+ Convert and display an output tensor from BigGAN in the terminal. This function uses `libsixel` and will only work in a libsixel-compatible terminal. Please refer to https://github.com/saitoha/libsixel for more details.
195
+ - Params:
196
+ - obj: tensor or numpy array of shape (batch_size, channels, height, width)
197
199
+
200
+ - `one_hot_from_int(int_or_list, batch_size=1)`:
201
+
202
+ Create a one-hot vector from a class index or a list of class indices.
203
+ - Params:
204
+ - int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
205
+ - batch_size: batch size.
206
+ - If int_or_list is an int, create a batch of identical classes.
207
+ - If int_or_list is a list, we should have `len(int_or_list) == batch_size`
208
+ - Output:
209
+ - array of shape (batch_size, 1000)
210
+
211
+ - `one_hot_from_names(class_name, batch_size=1)`:
212
+
213
+ Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...). We use NLTK's WordNet search to try to find the relevant synset of ImageNet and take the first one. If we can't find it directly, we look at the hyponyms and hypernyms of the class name.
214
+ - Params:
215
+ - class_name: string containing the name of an imagenet object.
216
+ - Output:
217
+ - array of shape (batch_size, 1000)
218
+
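+ Putting these utilities together (a minimal sketch; the class indices and output file name are arbitrary examples):
+
+ ```python
+ import torch
+ from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int,
+                                        truncated_noise_sample, save_as_images)
+
+ model = BigGAN.from_pretrained('biggan-deep-128')
+ truncation = 0.5
+
+ noise = torch.from_numpy(truncated_noise_sample(batch_size=2, truncation=truncation, seed=0))
+ labels = torch.from_numpy(one_hot_from_int([207, 8], batch_size=2))  # two ImageNet class indices
+
+ with torch.no_grad():
+     images = model(noise, labels, truncation)
+
+ save_as_images(images, file_name='samples')  # writes samples_0.png and samples_1.png
+ ```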
219
+ ## Download and conversion scripts
220
+
221
+ Scripts to download and convert the TensorFlow models from TensorFlow Hub are provided in [./scripts](./scripts/).
222
+
223
+ The scripts can be used directly as:
224
+ ```bash
225
+ ./scripts/download_tf_hub_models.sh
226
+ ./scripts/convert_tf_hub_models.sh
227
+ ```
models/biggan/pytorch_biggan/full_requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ tensorflow
2
+ tensorflow-hub
3
+ Pillow
4
+ nltk
5
+ libsixel-python
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .config import BigGANConfig
2
+ from .model import BigGAN
3
+ from .file_utils import PYTORCH_PRETRAINED_BIGGAN_CACHE, cached_path
4
+ from .utils import (truncated_noise_sample, save_as_images,
5
+ convert_to_images, display_in_terminal,
6
+ one_hot_from_int, one_hot_from_names)
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py ADDED
@@ -0,0 +1,70 @@
1
+ # coding: utf-8
2
+ """
3
+ BigGAN config.
4
+ """
5
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
6
+
7
+ import copy
8
+ import json
9
+
10
+ class BigGANConfig(object):
11
+ """ Configuration class to store the configuration of a `BigGAN`.
12
+ Defaults are for the 128x128 model.
13
+ Each layers tuple is (up-sample in this layer?, input channels, output channels).
14
+ """
15
+ def __init__(self,
16
+ output_dim=128,
17
+ z_dim=128,
18
+ class_embed_dim=128,
19
+ channel_width=128,
20
+ num_classes=1000,
21
+ layers=[(False, 16, 16),
22
+ (True, 16, 16),
23
+ (False, 16, 16),
24
+ (True, 16, 8),
25
+ (False, 8, 8),
26
+ (True, 8, 4),
27
+ (False, 4, 4),
28
+ (True, 4, 2),
29
+ (False, 2, 2),
30
+ (True, 2, 1)],
31
+ attention_layer_position=8,
32
+ eps=1e-4,
33
+ n_stats=51):
34
+ """Constructs BigGANConfig. """
35
+ self.output_dim = output_dim
36
+ self.z_dim = z_dim
37
+ self.class_embed_dim = class_embed_dim
38
+ self.channel_width = channel_width
39
+ self.num_classes = num_classes
40
+ self.layers = layers
41
+ self.attention_layer_position = attention_layer_position
42
+ self.eps = eps
43
+ self.n_stats = n_stats
44
+
45
+ @classmethod
46
+ def from_dict(cls, json_object):
47
+ """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
48
+ config = BigGANConfig()
49
+ for key, value in json_object.items():
50
+ config.__dict__[key] = value
51
+ return config
52
+
53
+ @classmethod
54
+ def from_json_file(cls, json_file):
55
+ """Constructs a `BigGANConfig` from a json file of parameters."""
56
+ with open(json_file, "r", encoding='utf-8') as reader:
57
+ text = reader.read()
58
+ return cls.from_dict(json.loads(text))
59
+
60
+ def __repr__(self):
61
+ return str(self.to_json_string())
62
+
63
+ def to_dict(self):
64
+ """Serializes this instance to a Python dictionary."""
65
+ output = copy.deepcopy(self.__dict__)
66
+ return output
67
+
68
+ def to_json_string(self):
69
+ """Serializes this instance to a JSON string."""
70
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py ADDED
@@ -0,0 +1,312 @@
1
+ # coding: utf-8
2
+ """
3
+ Convert a TF Hub BigGAN model into a PyTorch one.
4
+ """
5
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
6
+
7
+ from itertools import chain
8
+
9
+ import os
10
+ import argparse
11
+ import logging
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.nn.functional import normalize
17
+
18
+ from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME
19
+ from .config import BigGANConfig
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None):
25
+ try:
26
+ import numpy as np
27
+ import tensorflow as tf
28
+ import tensorflow_hub as hub
29
+ except ImportError:
30
+ raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow and TF Hub to be installed. "
31
+ "Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow. "
32
+ "And see https://github.com/tensorflow/hub for installing Hub. "
33
+ "Probably pip install tensorflow tensorflow-hub")
34
+ tf.reset_default_graph()
35
+ logger.info('Loading BigGAN module from: {}'.format(tf_model_path))
36
+ module = hub.Module(tf_model_path)
37
+ inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
38
+ for k, v in module.get_input_info_dict().items()}
39
+ output = module(inputs)
40
+
41
+ initializer = tf.global_variables_initializer()
42
+ sess = tf.Session()
43
+ stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ())
44
+ numpy_stacks = []
45
+ for i in stacks:
46
+ logger.info("Retrieving module_apply_default/stack_{}".format(i))
47
+ try:
48
+ stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i)
49
+ except KeyError:
50
+ break # We have all the stats
51
+ numpy_stacks.append(sess.run(stack_var))
52
+
53
+ if batch_norm_stats_path is not None:
54
+ torch.save(numpy_stacks, batch_norm_stats_path)
55
+ else:
56
+ return numpy_stacks
57
+
58
+
59
+ def build_tf_to_pytorch_map(model, config):
60
+ """ Build a map from TF variables to PyTorch modules. """
61
+ tf_to_pt_map = {}
62
+
63
+ # Embeddings and GenZ
64
+ tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight,
65
+ 'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias,
66
+ 'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig,
67
+ 'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u})
68
+
69
+ # GBlock blocks
70
+ model_layer_idx = 0
71
+ for i, (up, in_channels, out_channels) in enumerate(config.layers):
72
+ if i == config.attention_layer_position:
73
+ model_layer_idx += 1
74
+ layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/"
75
+ layer_pnt = model.generator.layers[model_layer_idx]
76
+ for i in range(4): # Batchnorms
77
+ batch_str = layer_str + ("BatchNorm_%d/" % i if i > 0 else "BatchNorm/")
78
+ batch_pnt = getattr(layer_pnt, 'bn_%d' % i)
79
+ for name in ('offset', 'scale'):
80
+ sub_module_str = batch_str + name + "/"
81
+ sub_module_pnt = getattr(batch_pnt, name)
82
+ tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
83
+ sub_module_str + "u0": sub_module_pnt.weight_u})
84
+ for i in range(4): # Convolutions
85
+ conv_str = layer_str + "conv%d/" % i
86
+ conv_pnt = getattr(layer_pnt, 'conv_%d' % i)
87
+ tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias,
88
+ conv_str + "w/ema_0.9999": conv_pnt.weight_orig,
89
+ conv_str + "u0": conv_pnt.weight_u})
90
+ model_layer_idx += 1
91
+
92
+ # Attention block
93
+ layer_str = "Generator/attention/"
94
+ layer_pnt = model.generator.layers[config.attention_layer_position]
95
+ tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma})
96
+ for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'],
97
+ ['g/', 'o_conv/', 'phi/', 'theta/']):
98
+ sub_module_str = layer_str + tf_name
99
+ sub_module_pnt = getattr(layer_pnt, pt_name)
100
+ tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
101
+ sub_module_str + "u0": sub_module_pnt.weight_u})
102
+
103
+ # final batch norm and conv to rgb
104
+ layer_str = "Generator/BatchNorm/"
105
+ layer_pnt = model.generator.bn
106
+ tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias,
107
+ layer_str + "scale/ema_0.9999": layer_pnt.weight})
108
+ layer_str = "Generator/conv_to_rgb/"
109
+ layer_pnt = model.generator.conv_to_rgb
110
+ tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias,
111
+ layer_str + "w/ema_0.9999": layer_pnt.weight_orig,
112
+ layer_str + "u0": layer_pnt.weight_u})
113
+ return tf_to_pt_map
114
+
115
+
116
+ def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None):
117
+ """ Load tf checkpoints and standing statistics in a pytorch model
118
+ """
119
+ try:
120
+ import numpy as np
121
+ import tensorflow as tf
122
+ except ImportError:
123
+ raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
124
+ "https://www.tensorflow.org/install/ for installation instructions.")
125
+ # Load weights from TF model
126
+ checkpoint_path = tf_model_path + "/variables/variables"
127
+ init_vars = tf.train.list_variables(checkpoint_path)
128
+ from pprint import pprint
129
+ pprint(init_vars)
130
+
131
+ # Extract batch norm statistics from model if needed
132
+ if batch_norm_stats_path:
133
+ stats = torch.load(batch_norm_stats_path)
134
+ else:
135
+ logger.info("Extracting batch norm stats")
136
+ stats = extract_batch_norm_stats(tf_model_path)
137
+
138
+ # Build TF to PyTorch weights loading map
139
+ tf_to_pt_map = build_tf_to_pytorch_map(model, config)
140
+
141
+ tf_weights = {}
142
+ for name in tf_to_pt_map.keys():
143
+ array = tf.train.load_variable(checkpoint_path, name)
144
+ tf_weights[name] = array
145
+ # logger.info("Loading TF weight {} with shape {}".format(name, array.shape))
146
+
147
+ # Load parameters
148
+ with torch.no_grad():
149
+ pt_params_pnt = set()
150
+ for name, pointer in tf_to_pt_map.items():
151
+ array = tf_weights[name]
152
+ if pointer.dim() == 1:
153
+ if pointer.dim() < array.ndim:
154
+ array = np.squeeze(array)
155
+ elif pointer.dim() == 2: # Weights
156
+ array = np.transpose(array)
157
+ elif pointer.dim() == 4: # Convolutions
158
+ array = np.transpose(array, (3, 2, 0, 1))
159
+ else:
160
+ raise ValueError("Wrong dimensions to adjust: " + str((pointer.shape, array.shape)))
161
+ if pointer.shape != array.shape:
162
+ raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape)))
163
+ logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape))
164
+ pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array)
165
+ tf_weights.pop(name, None)
166
+ pt_params_pnt.add(pointer.data_ptr())
167
+
168
+ # Prepare SpectralNorm buffers by running one step of Spectral Norm (no need to train the model):
169
+ for module in model.modules():
170
+ for n, buffer in module.named_buffers():
171
+ if n == 'weight_v':
172
+ weight_mat = module.weight_orig
173
+ weight_mat = weight_mat.reshape(weight_mat.size(0), -1)
174
+ u = module.weight_u
175
+
176
+ v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps)
177
+ buffer.data = v
178
+ pt_params_pnt.add(buffer.data_ptr())
179
+
180
+ u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps)
181
+ module.weight_u.data = u
182
+ pt_params_pnt.add(module.weight_u.data_ptr())
183
+
184
+ # Load batch norm statistics
185
+ index = 0
186
+ for layer in model.generator.layers:
187
+ if not hasattr(layer, 'bn_0'):
188
+ continue
189
+ for i in range(4): # Batchnorms
190
+ bn_pointer = getattr(layer, 'bn_%d' % i)
191
+ pointer = bn_pointer.running_means
192
+ if pointer.shape != stats[index].shape:
193
+ raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
194
+ pointer.data = torch.from_numpy(stats[index])
195
+ pt_params_pnt.add(pointer.data_ptr())
196
+
197
+ pointer = bn_pointer.running_vars
198
+ if pointer.shape != stats[index+1].shape:
199
+ raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
200
+ pointer.data = torch.from_numpy(stats[index+1])
201
+ pt_params_pnt.add(pointer.data_ptr())
202
+
203
+ index += 2
204
+
205
+ bn_pointer = model.generator.bn
206
+ pointer = bn_pointer.running_means
207
+ if pointer.shape != stats[index].shape:
208
+ raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
209
+ pointer.data = torch.from_numpy(stats[index])
210
+ pt_params_pnt.add(pointer.data_ptr())
211
+
212
+ pointer = bn_pointer.running_vars
213
+ if pointer.shape != stats[index+1].shape:
214
+ raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
215
+ pointer.data = torch.from_numpy(stats[index+1])
216
+ pt_params_pnt.add(pointer.data_ptr())
217
+
218
+ remaining_params = list(n for n, t in chain(model.named_parameters(), model.named_buffers()) \
219
+ if t.data_ptr() not in pt_params_pnt)
220
+
221
+ logger.info("TF Weights not copied to PyTorch model: {} -".format(', '.join(tf_weights.keys())))
222
+ logger.info("Remaining parameters/buffers from PyTorch model: {} -".format(', '.join(remaining_params)))
223
+
224
+ return model
225
+
226
+
227
+ BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
228
+ layers=[(False, 16, 16),
229
+ (True, 16, 16),
230
+ (False, 16, 16),
231
+ (True, 16, 8),
232
+ (False, 8, 8),
233
+ (True, 8, 4),
234
+ (False, 4, 4),
235
+ (True, 4, 2),
236
+ (False, 2, 2),
237
+ (True, 2, 1)],
238
+ attention_layer_position=8, eps=1e-4, n_stats=51)
239
+
240
+ BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
241
+ layers=[(False, 16, 16),
242
+ (True, 16, 16),
243
+ (False, 16, 16),
244
+ (True, 16, 8),
245
+ (False, 8, 8),
246
+ (True, 8, 8),
247
+ (False, 8, 8),
248
+ (True, 8, 4),
249
+ (False, 4, 4),
250
+ (True, 4, 2),
251
+ (False, 2, 2),
252
+ (True, 2, 1)],
253
+ attention_layer_position=8, eps=1e-4, n_stats=51)
254
+
255
+ BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
256
+ layers=[(False, 16, 16),
257
+ (True, 16, 16),
258
+ (False, 16, 16),
259
+ (True, 16, 8),
260
+ (False, 8, 8),
261
+ (True, 8, 8),
262
+ (False, 8, 8),
263
+ (True, 8, 4),
264
+ (False, 4, 4),
265
+ (True, 4, 2),
266
+ (False, 2, 2),
267
+ (True, 2, 1),
268
+ (False, 1, 1),
269
+ (True, 1, 1)],
270
+ attention_layer_position=8, eps=1e-4, n_stats=51)
271
+
272
+
273
+ def main():
274
+ parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model in a PyTorch model")
275
+ parser.add_argument("--model_type", type=str, default="", required=True,
276
+ help="BigGAN model type (128, 256, 512)")
277
+ parser.add_argument("--tf_model_path", type=str, default="", required=True,
278
+ help="Path of the downloaded TF Hub model")
279
+ parser.add_argument("--pt_save_path", type=str, default="",
280
+ help="Folder to save the PyTorch model (default: Folder of the TF Hub model)")
281
+ parser.add_argument("--batch_norm_stats_path", type=str, default="",
282
+ help="Path of previously extracted batch norm statistics")
283
+ args = parser.parse_args()
284
+
285
+ logging.basicConfig(level=logging.INFO)
286
+
287
+ if not args.pt_save_path:
288
+ args.pt_save_path = args.tf_model_path
289
+
290
+ if args.model_type == "128":
291
+ config = BigGAN128
292
+ elif args.model_type == "256":
293
+ config = BigGAN256
294
+ elif args.model_type == "512":
295
+ config = BigGAN512
296
+ else:
297
+ raise ValueError("model_type should be one of 128, 256 or 512")
298
+
299
+ model = BigGAN(config)
300
+ model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path)
301
+
302
+ model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME)
303
+ config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME)
304
+
305
+ logger.info("Save model dump to {}".format(model_save_path))
306
+ torch.save(model.state_dict(), model_save_path)
307
+ logger.info("Save configuration file to {}".format(config_save_path))
308
+ with open(config_save_path, "w", encoding="utf-8") as f:
309
+ f.write(config.to_json_string())
310
+
311
+ if __name__ == "__main__":
312
+ main()
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py ADDED
@@ -0,0 +1,249 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import tempfile
13
+ from functools import wraps
14
+ from hashlib import sha256
15
+ import sys
16
+ from io import open
17
+
18
+ import boto3
19
+ import requests
20
+ from botocore.exceptions import ClientError
21
+ from tqdm import tqdm
22
+
23
+ try:
24
+ from urllib.parse import urlparse
25
+ except ImportError:
26
+ from urlparse import urlparse
27
+
28
+ try:
29
+ from pathlib import Path
30
+ PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
31
+ Path.home() / '.pytorch_pretrained_biggan'))
32
+ except (AttributeError, ImportError):
33
+ PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
34
+ os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))
35
+
36
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ def url_to_filename(url, etag=None):
40
+ """
41
+ Convert `url` into a hashed filename in a repeatable way.
42
+ If `etag` is specified, append its hash to the url's, delimited
43
+ by a period.
44
+ """
45
+ url_bytes = url.encode('utf-8')
46
+ url_hash = sha256(url_bytes)
47
+ filename = url_hash.hexdigest()
48
+
49
+ if etag:
50
+ etag_bytes = etag.encode('utf-8')
51
+ etag_hash = sha256(etag_bytes)
52
+ filename += '.' + etag_hash.hexdigest()
53
+
54
+ return filename
55
+
56
+
57
+ def filename_to_url(filename, cache_dir=None):
58
+ """
59
+ Return the url and etag (which may be ``None``) stored for `filename`.
60
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61
+ """
62
+ if cache_dir is None:
63
+ cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
64
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65
+ cache_dir = str(cache_dir)
66
+
67
+ cache_path = os.path.join(cache_dir, filename)
68
+ if not os.path.exists(cache_path):
69
+ raise EnvironmentError("file {} not found".format(cache_path))
70
+
71
+ meta_path = cache_path + '.json'
72
+ if not os.path.exists(meta_path):
73
+ raise EnvironmentError("file {} not found".format(meta_path))
74
+
75
+ with open(meta_path, encoding="utf-8") as meta_file:
76
+ metadata = json.load(meta_file)
77
+ url = metadata['url']
78
+ etag = metadata['etag']
79
+
80
+ return url, etag
81
+
82
+
83
+ def cached_path(url_or_filename, cache_dir=None):
84
+ """
85
+ Given something that might be a URL (or might be a local path),
86
+ determine which. If it's a URL, download the file and cache it, and
87
+ return the path to the cached file. If it's already a local path,
88
+ make sure the file exists and then return the path.
89
+ """
90
+ if cache_dir is None:
91
+ cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
92
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93
+ url_or_filename = str(url_or_filename)
94
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95
+ cache_dir = str(cache_dir)
96
+
97
+ parsed = urlparse(url_or_filename)
98
+
99
+ if parsed.scheme in ('http', 'https', 's3'):
100
+ # URL, so get it from the cache (downloading if necessary)
101
+ return get_from_cache(url_or_filename, cache_dir)
102
+ elif os.path.exists(url_or_filename):
103
+ # File, and it exists.
104
+ return url_or_filename
105
+ elif parsed.scheme == '':
106
+ # File, but it doesn't exist.
107
+ raise EnvironmentError("file {} not found".format(url_or_filename))
108
+ else:
109
+ # Something unknown
110
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111
+
112
+
113
+ def split_s3_path(url):
114
+ """Split a full s3 path into the bucket name and path."""
115
+ parsed = urlparse(url)
116
+ if not parsed.netloc or not parsed.path:
117
+ raise ValueError("bad s3 path {}".format(url))
118
+ bucket_name = parsed.netloc
119
+ s3_path = parsed.path
120
+ # Remove '/' at beginning of path.
121
+ if s3_path.startswith("/"):
122
+ s3_path = s3_path[1:]
123
+ return bucket_name, s3_path
124
+
125
+
126
+ def s3_request(func):
127
+ """
128
+ Wrapper function for s3 requests in order to create more helpful error
129
+ messages.
130
+ """
131
+
132
+ @wraps(func)
133
+ def wrapper(url, *args, **kwargs):
134
+ try:
135
+ return func(url, *args, **kwargs)
136
+ except ClientError as exc:
137
+ if int(exc.response["Error"]["Code"]) == 404:
138
+ raise EnvironmentError("file {} not found".format(url))
139
+ else:
140
+ raise
141
+
142
+ return wrapper
143
+
144
+
145
+ @s3_request
146
+ def s3_etag(url):
147
+ """Check ETag on S3 object."""
148
+ s3_resource = boto3.resource("s3")
149
+ bucket_name, s3_path = split_s3_path(url)
150
+ s3_object = s3_resource.Object(bucket_name, s3_path)
151
+ return s3_object.e_tag
152
+
153
+
154
+ @s3_request
155
+ def s3_get(url, temp_file):
156
+ """Pull a file directly from S3."""
157
+ s3_resource = boto3.resource("s3")
158
+ bucket_name, s3_path = split_s3_path(url)
159
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160
+
161
+
162
+ def http_get(url, temp_file):
163
+ req = requests.get(url, stream=True)
164
+ content_length = req.headers.get('Content-Length')
165
+ total = int(content_length) if content_length is not None else None
166
+ progress = tqdm(unit="B", total=total)
167
+ for chunk in req.iter_content(chunk_size=1024):
168
+ if chunk: # filter out keep-alive new chunks
169
+ progress.update(len(chunk))
170
+ temp_file.write(chunk)
171
+ progress.close()
172
+
173
+
174
+ def get_from_cache(url, cache_dir=None):
175
+ """
176
+ Given a URL, look for the corresponding dataset in the local cache.
177
+ If it's not there, download it. Then return the path to the cached file.
178
+ """
179
+ if cache_dir is None:
180
+ cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
181
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182
+ cache_dir = str(cache_dir)
183
+
184
+ if not os.path.exists(cache_dir):
185
+ os.makedirs(cache_dir)
186
+
187
+ # Get eTag to add to filename, if it exists.
188
+ if url.startswith("s3://"):
189
+ etag = s3_etag(url)
190
+ else:
191
+ response = requests.head(url, allow_redirects=True)
192
+ if response.status_code != 200:
193
+ raise IOError("HEAD request failed for url {} with status code {}"
194
+ .format(url, response.status_code))
195
+ etag = response.headers.get("ETag")
196
+
197
+ filename = url_to_filename(url, etag)
198
+
199
+ # get cache path to put the file
200
+ cache_path = os.path.join(cache_dir, filename)
201
+
202
+ if not os.path.exists(cache_path):
203
+ # Download to temporary file, then copy to cache dir once finished.
204
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
205
+ with tempfile.NamedTemporaryFile() as temp_file:
206
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207
+
208
+ # GET file object
209
+ if url.startswith("s3://"):
210
+ s3_get(url, temp_file)
211
+ else:
212
+ http_get(url, temp_file)
213
+
214
+ # we are copying the file before closing it, so flush to avoid truncation
215
+ temp_file.flush()
216
+ # shutil.copyfileobj() starts at the current position, so go to the start
217
+ temp_file.seek(0)
218
+
219
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220
+ with open(cache_path, 'wb') as cache_file:
221
+ shutil.copyfileobj(temp_file, cache_file)
222
+
223
+ logger.info("creating metadata file for %s", cache_path)
224
+ meta = {'url': url, 'etag': etag}
225
+ meta_path = cache_path + '.json'
226
+ with open(meta_path, 'w', encoding="utf-8") as meta_file:
227
+ json.dump(meta, meta_file)
228
+
229
+ logger.info("removing temp file %s", temp_file.name)
230
+
231
+ return cache_path
232
+
233
+
234
+ def read_set_from_file(filename):
235
+ '''
236
+ Extract a de-duped collection (set) of text from a file.
237
+ Expected file format is one item per line.
238
+ '''
239
+ collection = set()
240
+ with open(filename, 'r', encoding='utf-8') as file_:
241
+ for line in file_:
242
+ collection.add(line.rstrip())
243
+ return collection
244
+
245
+
246
+ def get_file_extension(path, dot=True, lower=True):
247
+ ext = os.path.splitext(path)[1]
248
+ ext = ext if dot else ext[1:]
249
+ return ext.lower() if lower else ext
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py ADDED
@@ -0,0 +1,345 @@
1
+ # coding: utf-8
2
+ """ BigGAN PyTorch model.
3
+ From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
4
+ By Andrew Brock, Jeff Donahue and Karen Simonyan.
5
+ https://openreview.net/forum?id=B1xsqj09Fm
6
+
7
+ PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
8
+ Some parts of the code are adapted from https://github.com/brain-research/self-attention-gan
9
+
10
+ This version only comprises the generator (since the discriminator's weights are not released).
11
+ This version only comprises the "deep" version of BigGAN (see publication).
12
+
13
+ Modified by Erik Härkönen:
14
+ * Added support for per-layer latent vectors
15
+ """
16
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
17
+
18
+ import os
19
+ import logging
20
+ import math
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from .config import BigGANConfig
28
+ from .file_utils import cached_path
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ PRETRAINED_MODEL_ARCHIVE_MAP = {
33
+ 'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
34
+ 'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
35
+ 'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
36
+ }
37
+
38
+ PRETRAINED_CONFIG_ARCHIVE_MAP = {
39
+ 'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
40
+ 'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
41
+ 'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
42
+ }
43
+
44
+ WEIGHTS_NAME = 'pytorch_model.bin'
45
+ CONFIG_NAME = 'config.json'
46
+
47
+
48
+ def snconv2d(eps=1e-12, **kwargs):
49
+ return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)
50
+
51
+ def snlinear(eps=1e-12, **kwargs):
52
+ return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)
53
+
54
+ def sn_embedding(eps=1e-12, **kwargs):
55
+ return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)
56
+
57
+ class SelfAttn(nn.Module):
58
+ """ Self attention Layer"""
59
+ def __init__(self, in_channels, eps=1e-12):
60
+ super(SelfAttn, self).__init__()
61
+ self.in_channels = in_channels
62
+ self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
63
+ kernel_size=1, bias=False, eps=eps)
64
+ self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
65
+ kernel_size=1, bias=False, eps=eps)
66
+ self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
67
+ kernel_size=1, bias=False, eps=eps)
68
+ self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
69
+ kernel_size=1, bias=False, eps=eps)
70
+ self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
71
+ self.softmax = nn.Softmax(dim=-1)
72
+ self.gamma = nn.Parameter(torch.zeros(1))
73
+
74
+ def forward(self, x):
75
+ _, ch, h, w = x.size()
76
+ # Theta path
77
+ theta = self.snconv1x1_theta(x)
78
+ theta = theta.view(-1, ch//8, h*w)
79
+ # Phi path
80
+ phi = self.snconv1x1_phi(x)
81
+ phi = self.maxpool(phi)
82
+ phi = phi.view(-1, ch//8, h*w//4)
83
+ # Attn map
84
+ attn = torch.bmm(theta.permute(0, 2, 1), phi)
85
+ attn = self.softmax(attn)
86
+ # g path
87
+ g = self.snconv1x1_g(x)
88
+ g = self.maxpool(g)
89
+ g = g.view(-1, ch//2, h*w//4)
90
+ # Attn_g - o_conv
91
+ attn_g = torch.bmm(g, attn.permute(0, 2, 1))
92
+ attn_g = attn_g.view(-1, ch//2, h, w)
93
+ attn_g = self.snconv1x1_o_conv(attn_g)
94
+ # Out
95
+ out = x + self.gamma*attn_g
96
+ return out
97
+
98
+
99
+ class BigGANBatchNorm(nn.Module):
100
+ """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
101
+ activation means and variances for various truncation parameters.
102
+
103
+ We cannot just rely on torch.batch_norm since it cannot handle
104
+ batched weights (pytorch 1.0.1). We compute batch norm ourselves without updating running means and variances.
105
+ If you want to train this model you should add running means and variance computation logic.
106
+ """
107
+ def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
108
+ super(BigGANBatchNorm, self).__init__()
109
+ self.num_features = num_features
110
+ self.eps = eps
111
+ self.conditional = conditional
112
+
113
+ # We use pre-computed statistics for n_stats values of truncation between 0 and 1
114
+ self.register_buffer('running_means', torch.zeros(n_stats, num_features))
115
+ self.register_buffer('running_vars', torch.ones(n_stats, num_features))
116
+ self.step_size = 1.0 / (n_stats - 1)
117
+
118
+ if conditional:
119
+ assert condition_vector_dim is not None
120
+ self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
121
+ self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
122
+ else:
123
+ self.weight = torch.nn.Parameter(torch.Tensor(num_features))
124
+ self.bias = torch.nn.Parameter(torch.Tensor(num_features))
125
+
126
+ def forward(self, x, truncation, condition_vector=None):
127
+ # Retrieve the pre-computed statistics associated with this truncation
128
+ coef, start_idx = math.modf(truncation / self.step_size)
129
+ start_idx = int(start_idx)
130
+ if coef != 0.0: # Interpolate
131
+ running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
132
+ running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
133
+ else:
134
+ running_mean = self.running_means[start_idx]
135
+ running_var = self.running_vars[start_idx]
136
+
137
+ if self.conditional:
138
+ running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
139
+ running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
140
+
141
+ weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
142
+ bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)
143
+
144
+ out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
145
+ else:
146
+ out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
147
+ training=False, momentum=0.0, eps=self.eps)
148
+
149
+ return out
150
+
151
+
152
+ class GenBlock(nn.Module):
153
+ def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
154
+ n_stats=51, eps=1e-12):
155
+ super(GenBlock, self).__init__()
156
+ self.up_sample = up_sample
157
+ self.drop_channels = (in_size != out_size)
158
+ middle_size = in_size // reduction_factor
159
+
160
+ self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
161
+ self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)
162
+
163
+ self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
164
+ self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
165
+
166
+ self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
167
+ self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
168
+
169
+ self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
170
+ self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)
171
+
172
+ self.relu = nn.ReLU()
173
+
174
+ def forward(self, x, cond_vector, truncation):
175
+ x0 = x
176
+
177
+ x = self.bn_0(x, truncation, cond_vector)
178
+ x = self.relu(x)
179
+ x = self.conv_0(x)
180
+
181
+ x = self.bn_1(x, truncation, cond_vector)
182
+ x = self.relu(x)
183
+ if self.up_sample:
184
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
185
+ x = self.conv_1(x)
186
+
187
+ x = self.bn_2(x, truncation, cond_vector)
188
+ x = self.relu(x)
189
+ x = self.conv_2(x)
190
+
191
+ x = self.bn_3(x, truncation, cond_vector)
192
+ x = self.relu(x)
193
+ x = self.conv_3(x)
194
+
195
+ if self.drop_channels:
196
+ new_channels = x0.shape[1] // 2
197
+ x0 = x0[:, :new_channels, ...]
198
+ if self.up_sample:
199
+ x0 = F.interpolate(x0, scale_factor=2, mode='nearest')
200
+
201
+ out = x + x0
202
+ return out
203
+
204
+ class Generator(nn.Module):
205
+ def __init__(self, config):
206
+ super(Generator, self).__init__()
207
+ self.config = config
208
+ ch = config.channel_width
209
+ condition_vector_dim = config.z_dim * 2
210
+
211
+ self.gen_z = snlinear(in_features=condition_vector_dim,
212
+ out_features=4 * 4 * 16 * ch, eps=config.eps)
213
+
214
+ layers = []
215
+ for i, layer in enumerate(config.layers):
216
+ if i == config.attention_layer_position:
217
+ layers.append(SelfAttn(ch*layer[1], eps=config.eps))
218
+ layers.append(GenBlock(ch*layer[1],
219
+ ch*layer[2],
220
+ condition_vector_dim,
221
+ up_sample=layer[0],
222
+ n_stats=config.n_stats,
223
+ eps=config.eps))
224
+ self.layers = nn.ModuleList(layers)
225
+
226
+ self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
227
+ self.relu = nn.ReLU()
228
+ self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
229
+ self.tanh = nn.Tanh()
230
+
231
+ def forward(self, cond_vector, truncation):
232
+ z = self.gen_z(cond_vector[0])
233
+
234
+ # We use this conversion step to be able to use TF weights:
235
+ # TF convention on shape is [batch, height, width, channels]
236
+ # PT convention on shape is [batch, channels, height, width]
237
+ z = z.view(-1, 4, 4, 16 * self.config.channel_width)
238
+ z = z.permute(0, 3, 1, 2).contiguous()
239
+
240
+ cond_idx = 1
241
+ for i, layer in enumerate(self.layers):
242
+ if isinstance(layer, GenBlock):
243
+ z = layer(z, cond_vector[cond_idx], truncation)
244
+ cond_idx += 1
245
+ else:
246
+ z = layer(z)
247
+
248
+ z = self.bn(z, truncation)
249
+ z = self.relu(z)
250
+ z = self.conv_to_rgb(z)
251
+ z = z[:, :3, ...]
252
+ z = self.tanh(z)
253
+ return z
254
+
255
+ class BigGAN(nn.Module):
256
+ """BigGAN Generator."""
257
+
258
+ @classmethod
259
+ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
260
+ if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
261
+ model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
262
+ config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
263
+ else:
264
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
265
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
266
+
267
+ try:
268
+ resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
269
+ resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
270
+ except EnvironmentError:
271
+ logger.error("Wrong model name, should be a valid path to a folder containing "
272
+ "a {} file and a {} file or a model name in {}".format(
273
+ WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
274
+ raise
275
+
276
+ logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))
277
+
278
+ # Load config
279
+ config = BigGANConfig.from_json_file(resolved_config_file)
280
+ logger.info("Model config {}".format(config))
281
+
282
+ # Instantiate model.
283
+ model = cls(config, *inputs, **kwargs)
284
+ state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
285
+ model.load_state_dict(state_dict, strict=False)
286
+ return model
287
+
288
+ def __init__(self, config):
289
+ super(BigGAN, self).__init__()
290
+ self.config = config
291
+ self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
292
+ self.generator = Generator(config)
293
+ self.n_latents = len(config.layers) + 1 # one for gen_z + one per layer
294
+
295
+ def forward(self, z, class_label, truncation):
296
+ assert 0 < truncation <= 1
297
+
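+ # z may be a single latent shared by every input, or a list with one latent per
+ # generator input (n_latents entries); class_label may likewise be a single vector or a list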
298
+ if not isinstance(z, list):
299
+ z = self.n_latents*[z]
300
+
301
+ if isinstance(class_label, list):
302
+ embed = [self.embeddings(l) for l in class_label]
303
+ else:
304
+ embed = self.n_latents*[self.embeddings(class_label)]
305
+
306
+ assert len(z) == self.n_latents, f'Expected {self.n_latents} latents, got {len(z)}'
307
+ assert len(embed) == self.n_latents, f'Expected {self.n_latents} class vectors, got {len(class_label)}'
308
+
309
+ cond_vectors = [torch.cat((z, e), dim=1) for (z, e) in zip(z, embed)]
310
+ z = self.generator(cond_vectors, truncation)
311
+ return z
312
+
313
+
314
+ if __name__ == "__main__":
315
+ import PIL
316
+ from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
317
+ from .convert_tf_to_pytorch import load_tf_weights_in_biggan
318
+
319
+ load_cache = False
320
+ cache_path = './saved_model.pt'
321
+ config = BigGANConfig()
322
+ model = BigGAN(config)
323
+ if not load_cache:
324
+ model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin')
325
+ torch.save(model.state_dict(), cache_path)
326
+ else:
327
+ model.load_state_dict(torch.load(cache_path))
328
+
329
+ model.eval()
330
+
331
+ truncation = 0.4
332
+ noise = truncated_noise_sample(batch_size=2, truncation=truncation)
333
+ label = one_hot_from_names('diver', batch_size=2)
334
+
335
+ # Tests
336
+ # noise = np.zeros((1, 128))
337
+ # label = [983]
338
+
339
+ noise = torch.tensor(noise, dtype=torch.float)
340
+ label = torch.tensor(label, dtype=torch.float)
341
+ with torch.no_grad():
342
+ outputs = model(noise, label, truncation)
343
+ print(outputs.shape)
344
+
345
+ save_as_images(outputs)
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py ADDED
@@ -0,0 +1,216 @@
1
+ # coding: utf-8
2
+ """ BigGAN utilities to prepare truncated noise samples and convert/save/display output images.
3
+ Also comprises ImageNet utilities to prepare one-hot input vectors for ImageNet classes.
4
+ We use Wordnet so you can just input a name in a string and automatically get a corresponding
5
+ imagenet class if it exists (or a hypo/hypernym exists in imagenet).
6
+ """
7
+ from __future__ import absolute_import, division, print_function, unicode_literals
8
+
9
+ import json
10
+ import logging
11
+ from io import BytesIO
12
+
13
+ import numpy as np
14
+ from scipy.stats import truncnorm
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ NUM_CLASSES = 1000
19
+
20
+
21
+ def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):
22
+ """ Create a truncated noise vector.
23
+ Params:
24
+ batch_size: batch size.
25
+ dim_z: dimension of z
26
+ truncation: truncation value to use
27
+ seed: seed for the random generator
28
+ Output:
29
+ array of shape (batch_size, dim_z)
30
+ """
31
+ state = None if seed is None else np.random.RandomState(seed)
32
+ values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
33
+ return truncation * values
34
+
35
+
36
+ def convert_to_images(obj):
37
+ """ Convert an output tensor from BigGAN in a list of images.
38
+ Params:
39
+ obj: tensor or numpy array of shape (batch_size, channels, height, width)
40
+ Output:
41
+ list of Pillow Images of size (height, width)
42
+ """
43
+ try:
44
+ import PIL
45
+ except ImportError:
46
+ raise ImportError("Please install Pillow to use images: pip install Pillow")
47
+
48
+ if not isinstance(obj, np.ndarray):
49
+ obj = obj.detach().numpy()
50
+
51
+ obj = obj.transpose((0, 2, 3, 1))
52
+ obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)
53
+
54
+ img = []
55
+ for i, out in enumerate(obj):
56
+ out_array = np.asarray(np.uint8(out), dtype=np.uint8)
57
+ img.append(PIL.Image.fromarray(out_array))
58
+ return img
59
+
60
+
61
+ def save_as_images(obj, file_name='output'):
62
+ """ Convert and save an output tensor from BigGAN in a list of saved images.
63
+ Params:
64
+ obj: tensor or numpy array of shape (batch_size, channels, height, width)
65
+ file_name: path and beginning of filename to save.
66
+ Images will be saved as `file_name_{image_number}.png`
67
+ """
68
+ img = convert_to_images(obj)
69
+
70
+ for i, out in enumerate(img):
71
+ current_file_name = file_name + '_%d.png' % i
72
+ logger.info("Saving image to {}".format(current_file_name))
73
+ out.save(current_file_name, 'png')
74
+
75
+
76
+ def display_in_terminal(obj):
77
+ """ Convert and display an output tensor from BigGAN in the terminal.
78
+ This function uses `libsixel` and will only work in a libsixel-compatible terminal.
79
+ Please refer to https://github.com/saitoha/libsixel for more details.
80
+
81
+ Params:
82
+ obj: tensor or numpy array of shape (batch_size, channels, height, width)
83
+ file_name: path and beginning of filename to save.
84
+ Images will be saved as `file_name_{image_number}.png`
85
+ """
86
+ try:
87
+ import PIL
88
+ from libsixel import (sixel_output_new, sixel_dither_new, sixel_dither_initialize,
89
+ sixel_dither_set_palette, sixel_dither_set_pixelformat,
90
+ sixel_dither_get, sixel_encode, sixel_dither_unref,
91
+ sixel_output_unref, SIXEL_PIXELFORMAT_RGBA8888,
92
+ SIXEL_PIXELFORMAT_RGB888, SIXEL_PIXELFORMAT_PAL8,
93
+ SIXEL_PIXELFORMAT_G8, SIXEL_PIXELFORMAT_G1, SIXEL_BUILTIN_G8, SIXEL_BUILTIN_G1)  # built-in dithers used for 'L' and '1' modes below
94
+ except ImportError:
95
+ raise ImportError("Display in Terminal requires Pillow, libsixel "
96
+ "and a libsixel compatible terminal. "
97
+ "Please read info at https://github.com/saitoha/libsixel "
98
+ "and install with pip install Pillow libsixel-python")
99
+
100
+ s = BytesIO()
101
+
102
+ images = convert_to_images(obj)
103
+ widths, heights = zip(*(i.size for i in images))
104
+
105
+ output_width = sum(widths)
106
+ output_height = max(heights)
107
+
108
+ output_image = PIL.Image.new('RGB', (output_width, output_height))
109
+
110
+ x_offset = 0
111
+ for im in images:
112
+ output_image.paste(im, (x_offset,0))
113
+ x_offset += im.size[0]
114
+
115
+ try:
116
+ data = output_image.tobytes()
117
+ except NotImplementedError:
118
+ data = output_image.tostring()
119
+ output = sixel_output_new(lambda data, s: s.write(data), s)
120
+
121
+ try:
122
+ if output_image.mode == 'RGBA':
123
+ dither = sixel_dither_new(256)
124
+ sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGBA8888)
125
+ elif output_image.mode == 'RGB':
126
+ dither = sixel_dither_new(256)
127
+ sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGB888)
128
+ elif output_image.mode == 'P':
129
+ palette = output_image.getpalette()
130
+ dither = sixel_dither_new(256)
131
+ sixel_dither_set_palette(dither, palette)
132
+ sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_PAL8)
133
+ elif output_image.mode == 'L':
134
+ dither = sixel_dither_get(SIXEL_BUILTIN_G8)
135
+ sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G8)
136
+ elif output_image.mode == '1':
137
+ dither = sixel_dither_get(SIXEL_BUILTIN_G1)
138
+ sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G1)
139
+ else:
140
+ raise RuntimeError('unexpected output_image mode')
141
+ try:
142
+ sixel_encode(data, output_width, output_height, 1, dither, output)
143
+ print(s.getvalue().decode('ascii'))
144
+ finally:
145
+ sixel_dither_unref(dither)
146
+ finally:
147
+ sixel_output_unref(output)
148
+
149
+
150
+ def one_hot_from_int(int_or_list, batch_size=1):
151
+ """ Create a one-hot vector from a class index or a list of class indices.
152
+ Params:
153
+ int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
154
+ batch_size: batch size.
155
+ If int_or_list is an int create a batch of identical classes.
156
+ If int_or_list is a list, we should have `len(int_or_list) == batch_size`
157
+ Output:
158
+ array of shape (batch_size, 1000)
159
+ """
160
+ if isinstance(int_or_list, int):
161
+ int_or_list = [int_or_list]
162
+
163
+ if len(int_or_list) == 1 and batch_size > 1:
164
+ int_or_list = [int_or_list[0]] * batch_size
165
+
166
+ assert batch_size == len(int_or_list)
167
+
168
+ array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)
169
+ for i, j in enumerate(int_or_list):
170
+ array[i, j] = 1.0
171
+ return array
172
+
173
+
174
+ def one_hot_from_names(class_name_or_list, batch_size=1):
175
+ """ Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...).
176
+ We use NLTK's WordNet search to find the relevant ImageNet synset and take the first match.
177
+ If we can't find it directly, we look at the hyponyms and hypernyms of the class name.
178
+
179
+ Params:
180
+ class_name_or_list: string containing the name of an imagenet object or a list of such strings (for a batch).
181
+ Output:
182
+ array of shape (batch_size, 1000)
183
+ """
184
+ try:
185
+ from nltk.corpus import wordnet as wn
186
+ except ImportError:
187
+ raise ImportError("You need to install nltk to use this function")
188
+
189
+ if not isinstance(class_name_or_list, (list, tuple)):
190
+ class_name_or_list = [class_name_or_list]
191
+ else:
192
+ batch_size = max(batch_size, len(class_name_or_list))
193
+
194
+ classes = []
195
+ for class_name in class_name_or_list:
196
+ class_name = class_name.replace(" ", "_")
197
+
198
+ original_synsets = wn.synsets(class_name)
199
+ original_synsets = list(filter(lambda s: s.pos() == 'n', original_synsets)) # keep only nouns
200
+ if not original_synsets:
201
+ return None
202
+
203
+ possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, original_synsets))
204
+ if possible_synsets:
205
+ classes.append(IMAGENET[possible_synsets[0].offset()])
206
+ else:
207
+ # try hypernyms and hyponyms
208
+ possible_synsets = sum([s.hypernyms() + s.hyponyms() for s in original_synsets], [])
209
+ possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, possible_synsets))
210
+ if possible_synsets:
211
+ classes.append(IMAGENET[possible_synsets[0].offset()])
212
+
213
+ return one_hot_from_int(classes, batch_size=batch_size)
214
+
215
+
216
+ IMAGENET = {1440764: 0, 1443537: 1, 1484850: 2, 1491361: 3, 1494475: 4, 1496331: 5, 1498041: 6, 1514668: 7, 1514859: 8, 1518878: 9, 1530575: 10, 1531178: 11, 1532829: 12, 1534433: 13, 1537544: 14, 1558993: 15, 1560419: 16, 1580077: 17, 1582220: 18, 1592084: 19, 1601694: 20, 1608432: 21, 1614925: 22, 1616318: 23, 1622779: 24, 1629819: 25, 1630670: 26, 1631663: 27, 1632458: 28, 1632777: 29, 1641577: 30, 1644373: 31, 1644900: 32, 1664065: 33, 1665541: 34, 1667114: 35, 1667778: 36, 1669191: 37, 1675722: 38, 1677366: 39, 1682714: 40, 1685808: 41, 1687978: 42, 1688243: 43, 1689811: 44, 1692333: 45, 1693334: 46, 1694178: 47, 1695060: 48, 1697457: 49, 1698640: 50, 1704323: 51, 1728572: 52, 1728920: 53, 1729322: 54, 1729977: 55, 1734418: 56, 1735189: 57, 1737021: 58, 1739381: 59, 1740131: 60, 1742172: 61, 1744401: 62, 1748264: 63, 1749939: 64, 1751748: 65, 1753488: 66, 1755581: 67, 1756291: 68, 1768244: 69, 1770081: 70, 1770393: 71, 1773157: 72, 1773549: 73, 1773797: 74, 1774384: 75, 1774750: 76, 1775062: 77, 1776313: 78, 1784675: 79, 1795545: 80, 1796340: 81, 1797886: 82, 1798484: 83, 1806143: 84, 1806567: 85, 1807496: 86, 1817953: 87, 1818515: 88, 1819313: 89, 1820546: 90, 1824575: 91, 1828970: 92, 1829413: 93, 1833805: 94, 1843065: 95, 1843383: 96, 1847000: 97, 1855032: 98, 1855672: 99, 1860187: 100, 1871265: 101, 1872401: 102, 1873310: 103, 1877812: 104, 1882714: 105, 1883070: 106, 1910747: 107, 1914609: 108, 1917289: 109, 1924916: 110, 1930112: 111, 1943899: 112, 1944390: 113, 1945685: 114, 1950731: 115, 1955084: 116, 1968897: 117, 1978287: 118, 1978455: 119, 1980166: 120, 1981276: 121, 1983481: 122, 1984695: 123, 1985128: 124, 1986214: 125, 1990800: 126, 2002556: 127, 2002724: 128, 2006656: 129, 2007558: 130, 2009229: 131, 2009912: 132, 2011460: 133, 2012849: 134, 2013706: 135, 2017213: 136, 2018207: 137, 2018795: 138, 2025239: 139, 2027492: 140, 2028035: 141, 2033041: 142, 2037110: 143, 2051845: 144, 2056570: 145, 2058221: 146, 2066245: 147, 2071294: 148, 2074367: 149, 2077923: 150, 2085620: 151, 2085782: 152, 2085936: 153, 2086079: 154, 2086240: 155, 2086646: 156, 2086910: 157, 2087046: 158, 2087394: 159, 2088094: 160, 2088238: 161, 2088364: 162, 2088466: 163, 2088632: 164, 2089078: 165, 2089867: 166, 2089973: 167, 2090379: 168, 2090622: 169, 2090721: 170, 2091032: 171, 2091134: 172, 2091244: 173, 2091467: 174, 2091635: 175, 2091831: 176, 2092002: 177, 2092339: 178, 2093256: 179, 2093428: 180, 2093647: 181, 2093754: 182, 2093859: 183, 2093991: 184, 2094114: 185, 2094258: 186, 2094433: 187, 2095314: 188, 2095570: 189, 2095889: 190, 2096051: 191, 2096177: 192, 2096294: 193, 2096437: 194, 2096585: 195, 2097047: 196, 2097130: 197, 2097209: 198, 2097298: 199, 2097474: 200, 2097658: 201, 2098105: 202, 2098286: 203, 2098413: 204, 2099267: 205, 2099429: 206, 2099601: 207, 2099712: 208, 2099849: 209, 2100236: 210, 2100583: 211, 2100735: 212, 2100877: 213, 2101006: 214, 2101388: 215, 2101556: 216, 2102040: 217, 2102177: 218, 2102318: 219, 2102480: 220, 2102973: 221, 2104029: 222, 2104365: 223, 2105056: 224, 2105162: 225, 2105251: 226, 2105412: 227, 2105505: 228, 2105641: 229, 2105855: 230, 2106030: 231, 2106166: 232, 2106382: 233, 2106550: 234, 2106662: 235, 2107142: 236, 2107312: 237, 2107574: 238, 2107683: 239, 2107908: 240, 2108000: 241, 2108089: 242, 2108422: 243, 2108551: 244, 2108915: 245, 2109047: 246, 2109525: 247, 2109961: 248, 2110063: 249, 2110185: 250, 2110341: 251, 2110627: 252, 2110806: 253, 2110958: 254, 2111129: 255, 2111277: 256, 2111500: 257, 2111889: 258, 2112018: 259, 2112137: 
260, 2112350: 261, 2112706: 262, 2113023: 263, 2113186: 264, 2113624: 265, 2113712: 266, 2113799: 267, 2113978: 268, 2114367: 269, 2114548: 270, 2114712: 271, 2114855: 272, 2115641: 273, 2115913: 274, 2116738: 275, 2117135: 276, 2119022: 277, 2119789: 278, 2120079: 279, 2120505: 280, 2123045: 281, 2123159: 282, 2123394: 283, 2123597: 284, 2124075: 285, 2125311: 286, 2127052: 287, 2128385: 288, 2128757: 289, 2128925: 290, 2129165: 291, 2129604: 292, 2130308: 293, 2132136: 294, 2133161: 295, 2134084: 296, 2134418: 297, 2137549: 298, 2138441: 299, 2165105: 300, 2165456: 301, 2167151: 302, 2168699: 303, 2169497: 304, 2172182: 305, 2174001: 306, 2177972: 307, 2190166: 308, 2206856: 309, 2219486: 310, 2226429: 311, 2229544: 312, 2231487: 313, 2233338: 314, 2236044: 315, 2256656: 316, 2259212: 317, 2264363: 318, 2268443: 319, 2268853: 320, 2276258: 321, 2277742: 322, 2279972: 323, 2280649: 324, 2281406: 325, 2281787: 326, 2317335: 327, 2319095: 328, 2321529: 329, 2325366: 330, 2326432: 331, 2328150: 332, 2342885: 333, 2346627: 334, 2356798: 335, 2361337: 336, 2363005: 337, 2364673: 338, 2389026: 339, 2391049: 340, 2395406: 341, 2396427: 342, 2397096: 343, 2398521: 344, 2403003: 345, 2408429: 346, 2410509: 347, 2412080: 348, 2415577: 349, 2417914: 350, 2422106: 351, 2422699: 352, 2423022: 353, 2437312: 354, 2437616: 355, 2441942: 356, 2442845: 357, 2443114: 358, 2443484: 359, 2444819: 360, 2445715: 361, 2447366: 362, 2454379: 363, 2457408: 364, 2480495: 365, 2480855: 366, 2481823: 367, 2483362: 368, 2483708: 369, 2484975: 370, 2486261: 371, 2486410: 372, 2487347: 373, 2488291: 374, 2488702: 375, 2489166: 376, 2490219: 377, 2492035: 378, 2492660: 379, 2493509: 380, 2493793: 381, 2494079: 382, 2497673: 383, 2500267: 384, 2504013: 385, 2504458: 386, 2509815: 387, 2510455: 388, 2514041: 389, 2526121: 390, 2536864: 391, 2606052: 392, 2607072: 393, 2640242: 394, 2641379: 395, 2643566: 396, 2655020: 397, 2666196: 398, 2667093: 399, 2669723: 400, 2672831: 401, 2676566: 402, 2687172: 403, 2690373: 404, 2692877: 405, 2699494: 406, 2701002: 407, 2704792: 408, 2708093: 409, 2727426: 410, 2730930: 411, 2747177: 412, 2749479: 413, 2769748: 414, 2776631: 415, 2777292: 416, 2782093: 417, 2783161: 418, 2786058: 419, 2787622: 420, 2788148: 421, 2790996: 422, 2791124: 423, 2791270: 424, 2793495: 425, 2794156: 426, 2795169: 427, 2797295: 428, 2799071: 429, 2802426: 430, 2804414: 431, 2804610: 432, 2807133: 433, 2808304: 434, 2808440: 435, 2814533: 436, 2814860: 437, 2815834: 438, 2817516: 439, 2823428: 440, 2823750: 441, 2825657: 442, 2834397: 443, 2835271: 444, 2837789: 445, 2840245: 446, 2841315: 447, 2843684: 448, 2859443: 449, 2860847: 450, 2865351: 451, 2869837: 452, 2870880: 453, 2871525: 454, 2877765: 455, 2879718: 456, 2883205: 457, 2892201: 458, 2892767: 459, 2894605: 460, 2895154: 461, 2906734: 462, 2909870: 463, 2910353: 464, 2916936: 465, 2917067: 466, 2927161: 467, 2930766: 468, 2939185: 469, 2948072: 470, 2950826: 471, 2951358: 472, 2951585: 473, 2963159: 474, 2965783: 475, 2966193: 476, 2966687: 477, 2971356: 478, 2974003: 479, 2977058: 480, 2978881: 481, 2979186: 482, 2980441: 483, 2981792: 484, 2988304: 485, 2992211: 486, 2992529: 487, 2999410: 488, 3000134: 489, 3000247: 490, 3000684: 491, 3014705: 492, 3016953: 493, 3017168: 494, 3018349: 495, 3026506: 496, 3028079: 497, 3032252: 498, 3041632: 499, 3042490: 500, 3045698: 501, 3047690: 502, 3062245: 503, 3063599: 504, 3063689: 505, 3065424: 506, 3075370: 507, 3085013: 508, 3089624: 509, 3095699: 510, 3100240: 511, 3109150: 512, 3110669: 513, 
3124043: 514, 3124170: 515, 3125729: 516, 3126707: 517, 3127747: 518, 3127925: 519, 3131574: 520, 3133878: 521, 3134739: 522, 3141823: 523, 3146219: 524, 3160309: 525, 3179701: 526, 3180011: 527, 3187595: 528, 3188531: 529, 3196217: 530, 3197337: 531, 3201208: 532, 3207743: 533, 3207941: 534, 3208938: 535, 3216828: 536, 3218198: 537, 3220513: 538, 3223299: 539, 3240683: 540, 3249569: 541, 3250847: 542, 3255030: 543, 3259280: 544, 3271574: 545, 3272010: 546, 3272562: 547, 3290653: 548, 3291819: 549, 3297495: 550, 3314780: 551, 3325584: 552, 3337140: 553, 3344393: 554, 3345487: 555, 3347037: 556, 3355925: 557, 3372029: 558, 3376595: 559, 3379051: 560, 3384352: 561, 3388043: 562, 3388183: 563, 3388549: 564, 3393912: 565, 3394916: 566, 3400231: 567, 3404251: 568, 3417042: 569, 3424325: 570, 3425413: 571, 3443371: 572, 3444034: 573, 3445777: 574, 3445924: 575, 3447447: 576, 3447721: 577, 3450230: 578, 3452741: 579, 3457902: 580, 3459775: 581, 3461385: 582, 3467068: 583, 3476684: 584, 3476991: 585, 3478589: 586, 3481172: 587, 3482405: 588, 3483316: 589, 3485407: 590, 3485794: 591, 3492542: 592, 3494278: 593, 3495258: 594, 3496892: 595, 3498962: 596, 3527444: 597, 3529860: 598, 3530642: 599, 3532672: 600, 3534580: 601, 3535780: 602, 3538406: 603, 3544143: 604, 3584254: 605, 3584829: 606, 3590841: 607, 3594734: 608, 3594945: 609, 3595614: 610, 3598930: 611, 3599486: 612, 3602883: 613, 3617480: 614, 3623198: 615, 3627232: 616, 3630383: 617, 3633091: 618, 3637318: 619, 3642806: 620, 3649909: 621, 3657121: 622, 3658185: 623, 3661043: 624, 3662601: 625, 3666591: 626, 3670208: 627, 3673027: 628, 3676483: 629, 3680355: 630, 3690938: 631, 3691459: 632, 3692522: 633, 3697007: 634, 3706229: 635, 3709823: 636, 3710193: 637, 3710637: 638, 3710721: 639, 3717622: 640, 3720891: 641, 3721384: 642, 3724870: 643, 3729826: 644, 3733131: 645, 3733281: 646, 3733805: 647, 3742115: 648, 3743016: 649, 3759954: 650, 3761084: 651, 3763968: 652, 3764736: 653, 3769881: 654, 3770439: 655, 3770679: 656, 3773504: 657, 3775071: 658, 3775546: 659, 3776460: 660, 3777568: 661, 3777754: 662, 3781244: 663, 3782006: 664, 3785016: 665, 3786901: 666, 3787032: 667, 3788195: 668, 3788365: 669, 3791053: 670, 3792782: 671, 3792972: 672, 3793489: 673, 3794056: 674, 3796401: 675, 3803284: 676, 3804744: 677, 3814639: 678, 3814906: 679, 3825788: 680, 3832673: 681, 3837869: 682, 3838899: 683, 3840681: 684, 3841143: 685, 3843555: 686, 3854065: 687, 3857828: 688, 3866082: 689, 3868242: 690, 3868863: 691, 3871628: 692, 3873416: 693, 3874293: 694, 3874599: 695, 3876231: 696, 3877472: 697, 3877845: 698, 3884397: 699, 3887697: 700, 3888257: 701, 3888605: 702, 3891251: 703, 3891332: 704, 3895866: 705, 3899768: 706, 3902125: 707, 3903868: 708, 3908618: 709, 3908714: 710, 3916031: 711, 3920288: 712, 3924679: 713, 3929660: 714, 3929855: 715, 3930313: 716, 3930630: 717, 3933933: 718, 3935335: 719, 3937543: 720, 3938244: 721, 3942813: 722, 3944341: 723, 3947888: 724, 3950228: 725, 3954731: 726, 3956157: 727, 3958227: 728, 3961711: 729, 3967562: 730, 3970156: 731, 3976467: 732, 3976657: 733, 3977966: 734, 3980874: 735, 3982430: 736, 3983396: 737, 3991062: 738, 3992509: 739, 3995372: 740, 3998194: 741, 4004767: 742, 4005630: 743, 4008634: 744, 4009552: 745, 4019541: 746, 4023962: 747, 4026417: 748, 4033901: 749, 4033995: 750, 4037443: 751, 4039381: 752, 4040759: 753, 4041544: 754, 4044716: 755, 4049303: 756, 4065272: 757, 4067472: 758, 4069434: 759, 4070727: 760, 4074963: 761, 4081281: 762, 4086273: 763, 4090263: 764, 4099969: 765, 4111531: 766, 4116512: 
767, 4118538: 768, 4118776: 769, 4120489: 770, 4125021: 771, 4127249: 772, 4131690: 773, 4133789: 774, 4136333: 775, 4141076: 776, 4141327: 777, 4141975: 778, 4146614: 779, 4147183: 780, 4149813: 781, 4152593: 782, 4153751: 783, 4154565: 784, 4162706: 785, 4179913: 786, 4192698: 787, 4200800: 788, 4201297: 789, 4204238: 790, 4204347: 791, 4208210: 792, 4209133: 793, 4209239: 794, 4228054: 795, 4229816: 796, 4235860: 797, 4238763: 798, 4239074: 799, 4243546: 800, 4251144: 801, 4252077: 802, 4252225: 803, 4254120: 804, 4254680: 805, 4254777: 806, 4258138: 807, 4259630: 808, 4263257: 809, 4264628: 810, 4265275: 811, 4266014: 812, 4270147: 813, 4273569: 814, 4275548: 815, 4277352: 816, 4285008: 817, 4286575: 818, 4296562: 819, 4310018: 820, 4311004: 821, 4311174: 822, 4317175: 823, 4325704: 824, 4326547: 825, 4328186: 826, 4330267: 827, 4332243: 828, 4335435: 829, 4336792: 830, 4344873: 831, 4346328: 832, 4347754: 833, 4350905: 834, 4355338: 835, 4355933: 836, 4356056: 837, 4357314: 838, 4366367: 839, 4367480: 840, 4370456: 841, 4371430: 842, 4371774: 843, 4372370: 844, 4376876: 845, 4380533: 846, 4389033: 847, 4392985: 848, 4398044: 849, 4399382: 850, 4404412: 851, 4409515: 852, 4417672: 853, 4418357: 854, 4423845: 855, 4428191: 856, 4429376: 857, 4435653: 858, 4442312: 859, 4443257: 860, 4447861: 861, 4456115: 862, 4458633: 863, 4461696: 864, 4462240: 865, 4465501: 866, 4467665: 867, 4476259: 868, 4479046: 869, 4482393: 870, 4483307: 871, 4485082: 872, 4486054: 873, 4487081: 874, 4487394: 875, 4493381: 876, 4501370: 877, 4505470: 878, 4507155: 879, 4509417: 880, 4515003: 881, 4517823: 882, 4522168: 883, 4523525: 884, 4525038: 885, 4525305: 886, 4532106: 887, 4532670: 888, 4536866: 889, 4540053: 890, 4542943: 891, 4548280: 892, 4548362: 893, 4550184: 894, 4552348: 895, 4553703: 896, 4554684: 897, 4557648: 898, 4560804: 899, 4562935: 900, 4579145: 901, 4579432: 902, 4584207: 903, 4589890: 904, 4590129: 905, 4591157: 906, 4591713: 907, 4592741: 908, 4596742: 909, 4597913: 910, 4599235: 911, 4604644: 912, 4606251: 913, 4612504: 914, 4613696: 915, 6359193: 916, 6596364: 917, 6785654: 918, 6794110: 919, 6874185: 920, 7248320: 921, 7565083: 922, 7579787: 923, 7583066: 924, 7584110: 925, 7590611: 926, 7613480: 927, 7614500: 928, 7615774: 929, 7684084: 930, 7693725: 931, 7695742: 932, 7697313: 933, 7697537: 934, 7711569: 935, 7714571: 936, 7714990: 937, 7715103: 938, 7716358: 939, 7716906: 940, 7717410: 941, 7717556: 942, 7718472: 943, 7718747: 944, 7720875: 945, 7730033: 946, 7734744: 947, 7742313: 948, 7745940: 949, 7747607: 950, 7749582: 951, 7753113: 952, 7753275: 953, 7753592: 954, 7754684: 955, 7760859: 956, 7768694: 957, 7802026: 958, 7831146: 959, 7836838: 960, 7860988: 961, 7871810: 962, 7873807: 963, 7875152: 964, 7880968: 965, 7892512: 966, 7920052: 967, 7930864: 968, 7932039: 969, 9193705: 970, 9229709: 971, 9246464: 972, 9256479: 973, 9288635: 974, 9332890: 975, 9399592: 976, 9421951: 977, 9428293: 978, 9468604: 979, 9472597: 980, 9835506: 981, 10148035: 982, 10565667: 983, 11879895: 984, 11939491: 985, 12057211: 986, 12144580: 987, 12267677: 988, 12620546: 989, 12768682: 990, 12985857: 991, 12998815: 992, 13037406: 993, 13040303: 994, 13044778: 995, 13052670: 996, 13054560: 997, 13133613: 998, 15075141: 999}
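The helpers above are easy to exercise in isolation. A short sketch (not part of the commit), assuming the vendored module is importable as `pytorch_pretrained_biggan.utils` and that the NLTK WordNet corpus is installed:

import numpy as np
from pytorch_pretrained_biggan.utils import (
    truncated_noise_sample, one_hot_from_int, one_hot_from_names)

# Noise is drawn from N(0, 1) truncated to [-2, 2], then scaled by `truncation`,
# so every entry lies in [-2 * truncation, 2 * truncation].
z = truncated_noise_sample(batch_size=4, dim_z=128, truncation=0.5, seed=0)
assert z.shape == (4, 128) and np.abs(z).max() <= 1.0

# Class vectors either by ImageNet index or by WordNet name lookup
# (falling back to hyponyms/hypernyms when the name itself is not an ImageNet class).
y_by_index = one_hot_from_int(207, batch_size=4)                  # index 207: golden retriever
y_by_name = one_hot_from_names('golden retriever', batch_size=4)  # None if no synset maps to ImageNet
assert y_by_index.shape == (4, 1000)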
models/biggan/pytorch_biggan/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ # PyTorch
2
+ torch>=0.4.1
3
+ # progress bars in model download and training scripts
4
+ tqdm
5
+ # Accessing files from S3 directly.
6
+ boto3
7
+ # Used for downloading models over HTTP
8
+ requests
models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ set -e
9
+ set -x
10
+
11
+ models="128 256 512"
12
+
13
+ mkdir -p models/model_128
14
+ mkdir -p models/model_256
15
+ mkdir -p models/model_512
16
+
17
+ # Convert TF Hub models.
18
+ for model in $models
19
+ do
20
+ pytorch_pretrained_biggan --model_type $model --tf_model_path models/model_$model --pt_save_path models/model_$model
21
+ done
models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ set -e
9
+ set -x
10
+
11
+ models="128 256 512"
12
+
13
+ mkdir -p models/model_128
14
+ mkdir -p models/model_256
15
+ mkdir -p models/model_512
16
+
17
+ # Download TF Hub models.
18
+ for model in $models
19
+ do
20
+ curl -L "https://tfhub.dev/deepmind/biggan-deep-$model/1?tf-hub-format=compressed" | tar -zxvC models/model_$model
21
+ done
models/biggan/pytorch_biggan/setup.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3
+
4
+ To create the package for pypi.
5
+
6
+ 1. Change the version in __init__.py and setup.py.
7
+
8
+ 2. Commit these changes with the message: "Release: VERSION"
9
+
10
+ 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
11
+ Push the tag to git: git push --tags origin master
12
+
13
+ 4. Build both the sources and the wheel. Do not change anything in setup.py between
14
+ creating the wheel and the source distribution (obviously).
15
+
16
+ For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory.
17
+ (this will build a wheel for the python version you use to build it - make sure you use python 3.x).
18
+
19
+ For the sources, run: "python setup.py sdist"
20
+ You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp.
21
+
22
+ 5. Check that everything looks correct by uploading the package to the pypi test server:
23
+
24
+ twine upload dist/* -r pypitest
25
+ (pypi suggest using twine as other methods upload files via plaintext.)
26
+
27
+ Check that you can install it in a virtualenv by running:
28
+ pip install -i https://testpypi.python.org/pypi allennlp
29
+
30
+ 6. Upload the final version to actual pypi:
31
+ twine upload dist/* -r pypi
32
+
33
+ 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
34
+
35
+ """
36
+ from io import open
37
+ from setuptools import find_packages, setup
38
+
39
+ setup(
40
+ name="pytorch_pretrained_biggan",
41
+ version="0.1.0",
42
+ author="Thomas Wolf",
43
+ author_email="[email protected]",
44
+ description="PyTorch version of DeepMind's BigGAN model with pre-trained models",
45
+ long_description=open("README.md", "r", encoding='utf-8').read(),
46
+ long_description_content_type="text/markdown",
47
+ keywords='BIGGAN GAN deep learning google deepmind',
48
+ license='Apache',
49
+ url="https://github.com/huggingface/pytorch-pretrained-BigGAN",
50
+ packages=find_packages(exclude=["*.tests", "*.tests.*",
51
+ "tests.*", "tests"]),
52
+ install_requires=['torch>=0.4.1',
53
+ 'numpy',
54
+ 'boto3',
55
+ 'requests',
56
+ 'tqdm'],
57
+ tests_require=['pytest'],
58
+ entry_points={
59
+ 'console_scripts': [
60
+ "pytorch_pretrained_biggan=pytorch_pretrained_biggan.convert_tf_to_pytorch:main",
61
+ ]
62
+ },
63
+ classifiers=[
64
+ 'Intended Audience :: Science/Research',
65
+ 'License :: OSI Approved :: Apache Software License',
66
+ 'Programming Language :: Python :: 3',
67
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
68
+ ],
69
+ )
models/stylegan/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2020 Erik Härkönen. All rights reserved.
2
+ # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License. You may obtain a copy
4
+ # of the License at http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software distributed under
7
+ # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8
+ # OF ANY KIND, either express or implied. See the License for the specific language
9
+ # governing permissions and limitations under the License.
10
+
11
+ from pathlib import Path
12
+ import sys
13
+
14
+ #module_path = Path(__file__).parent / 'pytorch_biggan'
15
+ #sys.path.append(str(module_path.resolve()))
16
+
17
+ from .model import StyleGAN_G, NoiseLayer
models/stylegan/model.py ADDED
@@ -0,0 +1,456 @@
1
+ # Copyright 2020 Erik Härkönen. All rights reserved.
2
+ # This file is licensed to you under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License. You may obtain a copy
4
+ # of the License at http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software distributed under
7
+ # the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
8
+ # OF ANY KIND, either express or implied. See the License for the specific language
9
+ # governing permissions and limitations under the License.
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from collections import OrderedDict
16
+ from pathlib import Path
17
+ import requests
18
+ import pickle
19
+ import sys
20
+
21
+ import numpy as np
22
+
23
+ # Reimplementation of StyleGAN in PyTorch
24
+ # Source: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
25
+
26
+ class MyLinear(nn.Module):
27
+ """Linear layer with equalized learning rate and custom learning rate multiplier."""
28
+ def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
29
+ super().__init__()
30
+ he_std = gain * input_size**(-0.5) # He init
31
+ # Equalized learning rate and custom learning rate multiplier.
32
+ if use_wscale:
33
+ init_std = 1.0 / lrmul
34
+ self.w_mul = he_std * lrmul
35
+ else:
36
+ init_std = he_std / lrmul
37
+ self.w_mul = lrmul
38
+ self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
39
+ if bias:
40
+ self.bias = torch.nn.Parameter(torch.zeros(output_size))
41
+ self.b_mul = lrmul
42
+ else:
43
+ self.bias = None
44
+
45
+ def forward(self, x):
46
+ bias = self.bias
47
+ if bias is not None:
48
+ bias = bias * self.b_mul
49
+ return F.linear(x, self.weight * self.w_mul, bias)
50
+
51
+ class MyConv2d(nn.Module):
52
+ """Conv layer with equalized learning rate and custom learning rate multiplier."""
53
+ def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
54
+ intermediate=None, upscale=False):
55
+ super().__init__()
56
+ if upscale:
57
+ self.upscale = Upscale2d()
58
+ else:
59
+ self.upscale = None
60
+ he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
61
+ self.kernel_size = kernel_size
62
+ if use_wscale:
63
+ init_std = 1.0 / lrmul
64
+ self.w_mul = he_std * lrmul
65
+ else:
66
+ init_std = he_std / lrmul
67
+ self.w_mul = lrmul
68
+ self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
69
+ if bias:
70
+ self.bias = torch.nn.Parameter(torch.zeros(output_channels))
71
+ self.b_mul = lrmul
72
+ else:
73
+ self.bias = None
74
+ self.intermediate = intermediate
75
+
76
+ def forward(self, x):
77
+ bias = self.bias
78
+ if bias is not None:
79
+ bias = bias * self.b_mul
80
+
81
+ have_convolution = False
82
+ if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
83
+ # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
84
+ # this really needs to be cleaned up and go into the conv...
85
+ w = self.weight * self.w_mul
86
+ w = w.permute(1, 0, 2, 3)
87
+ # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
88
+ w = F.pad(w, (1,1,1,1))
89
+ w = w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
90
+ x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2)
91
+ have_convolution = True
92
+ elif self.upscale is not None:
93
+ x = self.upscale(x)
94
+
95
+ if not have_convolution and self.intermediate is None:
96
+ return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2)
97
+ elif not have_convolution:
98
+ x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2)
99
+
100
+ if self.intermediate is not None:
101
+ x = self.intermediate(x)
102
+ if bias is not None:
103
+ x = x + bias.view(1, -1, 1, 1)
104
+ return x
105
+
106
+ class NoiseLayer(nn.Module):
107
+ """adds noise. noise is per pixel (constant over channels) with per-channel weight"""
108
+ def __init__(self, channels):
109
+ super().__init__()
110
+ self.weight = nn.Parameter(torch.zeros(channels))
111
+ self.noise = None
112
+
113
+ def forward(self, x, noise=None):
114
+ if noise is None and self.noise is None:
115
+ noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
116
+ elif noise is None:
117
+ # here is a little trick: if you get all the noiselayers and set each
118
+ # module's .noise attribute, you can have pre-defined noise.
119
+ # Very useful for analysis
120
+ noise = self.noise
121
+ x = x + self.weight.view(1, -1, 1, 1) * noise
122
+ return x
123
+
124
+ class StyleMod(nn.Module):
125
+ def __init__(self, latent_size, channels, use_wscale):
126
+ super(StyleMod, self).__init__()
127
+ self.lin = MyLinear(latent_size,
128
+ channels * 2,
129
+ gain=1.0, use_wscale=use_wscale)
130
+
131
+ def forward(self, x, latent):
132
+ style = self.lin(latent) # style => [batch_size, n_channels*2]
133
+ shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
134
+ style = style.view(shape) # [batch_size, 2, n_channels, ...]
135
+ x = x * (style[:, 0] + 1.) + style[:, 1]
136
+ return x
137
+
138
+ class PixelNormLayer(nn.Module):
139
+ def __init__(self, epsilon=1e-8):
140
+ super().__init__()
141
+ self.epsilon = epsilon
142
+ def forward(self, x):
143
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
144
+
145
+ class BlurLayer(nn.Module):
146
+ def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
147
+ super(BlurLayer, self).__init__()
148
+ kernel=[1, 2, 1]
149
+ kernel = torch.tensor(kernel, dtype=torch.float32)
150
+ kernel = kernel[:, None] * kernel[None, :]
151
+ kernel = kernel[None, None]
152
+ if normalize:
153
+ kernel = kernel / kernel.sum()
154
+ if flip:
155
+ kernel = kernel[:, :, ::-1, ::-1]
156
+ self.register_buffer('kernel', kernel)
157
+ self.stride = stride
158
+
159
+ def forward(self, x):
160
+ # expand kernel channels
161
+ kernel = self.kernel.expand(x.size(1), -1, -1, -1)
162
+ x = F.conv2d(
163
+ x,
164
+ kernel,
165
+ stride=self.stride,
166
+ padding=int((self.kernel.size(2)-1)/2),
167
+ groups=x.size(1)
168
+ )
169
+ return x
170
+
171
+ def upscale2d(x, factor=2, gain=1):
172
+ assert x.dim() == 4
173
+ if gain != 1:
174
+ x = x * gain
175
+ if factor != 1:
176
+ shape = x.shape
177
+ x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
178
+ x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
179
+ return x
180
+
181
+ class Upscale2d(nn.Module):
182
+ def __init__(self, factor=2, gain=1):
183
+ super().__init__()
184
+ assert isinstance(factor, int) and factor >= 1
185
+ self.gain = gain
186
+ self.factor = factor
187
+ def forward(self, x):
188
+ return upscale2d(x, factor=self.factor, gain=self.gain)
189
+
190
+ class G_mapping(nn.Sequential):
191
+ def __init__(self, nonlinearity='lrelu', use_wscale=True):
192
+ act, gain = {'relu': (torch.relu, np.sqrt(2)),
193
+ 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
194
+ layers = [
195
+ ('pixel_norm', PixelNormLayer()),
196
+ ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
197
+ ('dense0_act', act),
198
+ ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
199
+ ('dense1_act', act),
200
+ ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
201
+ ('dense2_act', act),
202
+ ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
203
+ ('dense3_act', act),
204
+ ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
205
+ ('dense4_act', act),
206
+ ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
207
+ ('dense5_act', act),
208
+ ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
209
+ ('dense6_act', act),
210
+ ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
211
+ ('dense7_act', act)
212
+ ]
213
+ super().__init__(OrderedDict(layers))
214
+
215
+ def forward(self, x):
216
+ return super().forward(x)
217
+
218
+ class Truncation(nn.Module):
219
+ def __init__(self, avg_latent, max_layer=8, threshold=0.7):
220
+ super().__init__()
221
+ self.max_layer = max_layer
222
+ self.threshold = threshold
223
+ self.register_buffer('avg_latent', avg_latent)
224
+ def forward(self, x):
225
+ assert x.dim() == 3
226
+ interp = torch.lerp(self.avg_latent, x, self.threshold)
227
+ do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1)
228
+ return torch.where(do_trunc, interp, x)
229
+
230
+ class LayerEpilogue(nn.Module):
231
+ """Things to do at the end of each layer."""
232
+ def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
233
+ super().__init__()
234
+ layers = []
235
+ if use_noise:
236
+ layers.append(('noise', NoiseLayer(channels)))
237
+ layers.append(('activation', activation_layer))
238
+ if use_pixel_norm:
239
+ layers.append(('pixel_norm', PixelNormLayer()))
240
+ if use_instance_norm:
241
+ layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
242
+ self.top_epi = nn.Sequential(OrderedDict(layers))
243
+ if use_styles:
244
+ self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
245
+ else:
246
+ self.style_mod = None
247
+ def forward(self, x, dlatents_in_slice=None):
248
+ x = self.top_epi(x)
249
+ if self.style_mod is not None:
250
+ x = self.style_mod(x, dlatents_in_slice)
251
+ else:
252
+ assert dlatents_in_slice is None
253
+ return x
254
+
255
+
256
+ class InputBlock(nn.Module):
257
+ def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
258
+ super().__init__()
259
+ self.const_input_layer = const_input_layer
260
+ self.nf = nf
261
+ if self.const_input_layer:
262
+ # called 'const' in tf
263
+ self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
264
+ self.bias = nn.Parameter(torch.ones(nf))
265
+ else:
266
+ self.dense = MyLinear(dlatent_size, nf*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressive GAN
267
+ self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
268
+ self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
269
+ self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
270
+
271
+ def forward(self, dlatents_in_range):
272
+ batch_size = dlatents_in_range.size(0)
273
+ if self.const_input_layer:
274
+ x = self.const.expand(batch_size, -1, -1, -1)
275
+ x = x + self.bias.view(1, -1, 1, 1)
276
+ else:
277
+ x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
278
+ x = self.epi1(x, dlatents_in_range[:, 0])
279
+ x = self.conv(x)
280
+ x = self.epi2(x, dlatents_in_range[:, 1])
281
+ return x
282
+
283
+
284
+ class GSynthesisBlock(nn.Module):
285
+ def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
286
+ # 2**res x 2**res # res = 3..resolution_log2
287
+ super().__init__()
288
+ if blur_filter:
289
+ blur = BlurLayer(blur_filter)
290
+ else:
291
+ blur = None
292
+ self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
293
+ intermediate=blur, upscale=True)
294
+ self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
295
+ self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
296
+ self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
297
+
298
+ def forward(self, x, dlatents_in_range):
299
+ x = self.conv0_up(x)
300
+ x = self.epi1(x, dlatents_in_range[:, 0])
301
+ x = self.conv1(x)
302
+ x = self.epi2(x, dlatents_in_range[:, 1])
303
+ return x
304
+
305
+ class G_synthesis(nn.Module):
306
+ def __init__(self,
307
+ dlatent_size = 512, # Disentangled latent (W) dimensionality.
308
+ num_channels = 3, # Number of output color channels.
309
+ resolution = 1024, # Output resolution.
310
+ fmap_base = 8192, # Overall multiplier for the number of feature maps.
311
+ fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
312
+ fmap_max = 512, # Maximum number of feature maps in any layer.
313
+ use_styles = True, # Enable style inputs?
314
+ const_input_layer = True, # First layer is a learned constant?
315
+ use_noise = True, # Enable noise inputs?
316
+ randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
317
+ nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu'
318
+ use_wscale = True, # Enable equalized learning rate?
319
+ use_pixel_norm = False, # Enable pixelwise feature vector normalization?
320
+ use_instance_norm = True, # Enable instance normalization?
321
+ dtype = torch.float32, # Data type to use for activations and outputs.
322
+ blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering.
323
+ ):
324
+
325
+ super().__init__()
326
+ def nf(stage):
327
+ return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
328
+ self.dlatent_size = dlatent_size
329
+ resolution_log2 = int(np.log2(resolution))
330
+ assert resolution == 2**resolution_log2 and resolution >= 4
331
+
332
+ act, gain = {'relu': (torch.relu, np.sqrt(2)),
333
+ 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
334
+ num_layers = resolution_log2 * 2 - 2
335
+ num_styles = num_layers if use_styles else 1
336
+ torgbs = []
337
+ blocks = []
338
+ for res in range(2, resolution_log2 + 1):
339
+ channels = nf(res-1)
340
+ name = '{s}x{s}'.format(s=2**res)
341
+ if res == 2:
342
+ blocks.append((name,
343
+ InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
344
+ use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
345
+
346
+ else:
347
+ blocks.append((name,
348
+ GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
349
+ last_channels = channels
350
+ self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
351
+ self.blocks = nn.ModuleDict(OrderedDict(blocks))
352
+
353
+ def forward(self, dlatents_in):
354
+ # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
355
+ # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)
356
+ batch_size = dlatents_in.size(0)
357
+ for i, m in enumerate(self.blocks.values()):
358
+ if i == 0:
359
+ x = m(dlatents_in[:, 2*i:2*i+2])
360
+ else:
361
+ x = m(x, dlatents_in[:, 2*i:2*i+2])
362
+ rgb = self.torgb(x)
363
+ return rgb
364
+
365
+
366
+ class StyleGAN_G(nn.Sequential):
367
+ def __init__(self, resolution, truncation=1.0):
368
+ self.resolution = resolution
369
+ self.layers = OrderedDict([
370
+ ('g_mapping', G_mapping()),
371
+ #('truncation', Truncation(avg_latent)),
372
+ ('g_synthesis', G_synthesis(resolution=resolution)),
373
+ ])
374
+ super().__init__(self.layers)
375
+
376
+ def forward(self, x, latent_is_w=False):
377
+ if isinstance(x, list):
378
+ assert len(x) == 18, 'Must provide 1 or 18 latents'
379
+ if not latent_is_w:
380
+ x = [self.layers['g_mapping'].forward(l) for l in x]
381
+ x = torch.stack(x, dim=1)
382
+ else:
383
+ if not latent_is_w:
384
+ x = self.layers['g_mapping'].forward(x)
385
+ x = x.unsqueeze(1).expand(-1, 18, -1)
386
+
387
+ x = self.layers['g_synthesis'].forward(x)
388
+
389
+ return x
390
+
391
+ # From: https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/
392
+ def load_weights(self, checkpoint):
393
+ self.load_state_dict(torch.load(checkpoint))
394
+
395
+ def export_from_tf(self, pickle_path):
396
+ module_path = Path(__file__).parent / 'stylegan_tf'
397
+ sys.path.append(str(module_path.resolve()))
398
+
399
+ import dnnlib, dnnlib.tflib, pickle, torch, collections
400
+ dnnlib.tflib.init_tf()
401
+
402
+ weights = pickle.load(open(pickle_path,'rb'))
403
+ weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k,v in w.trainables.items()]) for w in weights]
404
+ #torch.save(weights_pt, pytorch_name)
405
+
406
+ # then on the PyTorch side run
407
+ state_G, state_D, state_Gs = weights_pt #torch.load('./karras2019stylegan-ffhq-1024x1024.pt')
408
+ def key_translate(k):
409
+ k = k.lower().split('/')
410
+ if k[0] == 'g_synthesis':
411
+ if not k[1].startswith('torgb'):
412
+ k.insert(1, 'blocks')
413
+ k = '.'.join(k)
414
+ k = (k.replace('const.const','const').replace('const.bias','bias').replace('const.stylemod','epi1.style_mod.lin')
415
+ .replace('const.noise.weight','epi1.top_epi.noise.weight')
416
+ .replace('conv.noise.weight','epi2.top_epi.noise.weight')
417
+ .replace('conv.stylemod','epi2.style_mod.lin')
418
+ .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
419
+ .replace('conv0_up.stylemod','epi1.style_mod.lin')
420
+ .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
421
+ .replace('conv1.stylemod','epi2.style_mod.lin')
422
+ .replace('torgb_lod0','torgb'))
423
+ else:
424
+ k = '.'.join(k)
425
+ return k
426
+
427
+ def weight_translate(k, w):
428
+ k = key_translate(k)
429
+ if k.endswith('.weight'):
430
+ if w.dim() == 2:
431
+ w = w.t()
432
+ elif w.dim() == 1:
433
+ pass
434
+ else:
435
+ assert w.dim() == 4
436
+ w = w.permute(3, 2, 0, 1)
437
+ return w
438
+
439
+ # we delete the useless torgb filters
440
+ param_dict = {key_translate(k) : weight_translate(k, v) for k,v in state_Gs.items() if 'torgb_lod' not in key_translate(k)}
441
+ if 1:
442
+ sd_shapes = {k : v.shape for k,v in self.state_dict().items()}
443
+ param_shapes = {k : v.shape for k,v in param_dict.items() }
444
+
445
+ for k in list(sd_shapes)+list(param_shapes):
446
+ pds = param_shapes.get(k)
447
+ sds = sd_shapes.get(k)
448
+ if pds is None:
449
+ print ("sd only", k, sds)
450
+ elif sds is None:
451
+ print ("pd only", k, pds)
452
+ elif sds != pds:
453
+ print ("mismatch!", k, pds, sds)
454
+
455
+ self.load_state_dict(param_dict, strict=False) # needed for the blur kernels
456
+ torch.save(self.state_dict(), Path(pickle_path).with_suffix('.pt'))
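A minimal generation sketch for the wrapper above (not part of the commit); the checkpoint filename is hypothetical and assumed to be a converted `.pt` file produced by `export_from_tf` or obtained separately:

import torch
from models.stylegan import StyleGAN_G  # repo-local import, re-exported in models/stylegan/__init__.py

G = StyleGAN_G(resolution=1024)
G.load_weights('karras2019stylegan-ffhq-1024x1024.pt')  # hypothetical converted checkpoint
G.eval()

with torch.no_grad():
    z = torch.randn(1, 512)
    img = G(z)                        # z -> G_mapping -> broadcast to 18 w slots -> G_synthesis
    w = G.layers['g_mapping'](z)      # or work directly in W space
    img_w = G(w, latent_is_w=True)    # skips the mapping network
print(img.shape)  # expected (1, 3, 1024, 1024)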
models/stylegan/stylegan_tf/LICENSE.txt ADDED
@@ -0,0 +1,410 @@
1
+ Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+
3
+
4
+ Attribution-NonCommercial 4.0 International
5
+
6
+ =======================================================================
7
+
8
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
9
+ does not provide legal services or legal advice. Distribution of
10
+ Creative Commons public licenses does not create a lawyer-client or
11
+ other relationship. Creative Commons makes its licenses and related
12
+ information available on an "as-is" basis. Creative Commons gives no
13
+ warranties regarding its licenses, any material licensed under their
14
+ terms and conditions, or any related information. Creative Commons
15
+ disclaims all liability for damages resulting from their use to the
16
+ fullest extent possible.
17
+
18
+ Using Creative Commons Public Licenses
19
+
20
+ Creative Commons public licenses provide a standard set of terms and
21
+ conditions that creators and other rights holders may use to share
22
+ original works of authorship and other material subject to copyright
23
+ and certain other rights specified in the public license below. The
24
+ following considerations are for informational purposes only, are not
25
+ exhaustive, and do not form part of our licenses.
26
+
27
+ Considerations for licensors: Our public licenses are
28
+ intended for use by those authorized to give the public
29
+ permission to use material in ways otherwise restricted by
30
+ copyright and certain other rights. Our licenses are
31
+ irrevocable. Licensors should read and understand the terms
32
+ and conditions of the license they choose before applying it.
33
+ Licensors should also secure all rights necessary before
34
+ applying our licenses so that the public can reuse the
35
+ material as expected. Licensors should clearly mark any
36
+ material not subject to the license. This includes other CC-
37
+ licensed material, or material used under an exception or
38
+ limitation to copyright. More considerations for licensors:
39
+ wiki.creativecommons.org/Considerations_for_licensors
40
+
41
+ Considerations for the public: By using one of our public
42
+ licenses, a licensor grants the public permission to use the
43
+ licensed material under specified terms and conditions. If
44
+ the licensor's permission is not necessary for any reason--for
45
+ example, because of any applicable exception or limitation to
46
+ copyright--then that use is not regulated by the license. Our
47
+ licenses grant only permissions under copyright and certain
48
+ other rights that a licensor has authority to grant. Use of
49
+ the licensed material may still be restricted for other
50
+ reasons, including because others have copyright or other
51
+ rights in the material. A licensor may make special requests,
52
+ such as asking that all changes be marked or described.
53
+ Although not required by our licenses, you are encouraged to
54
+ respect those requests where reasonable. More_considerations
55
+ for the public:
56
+ wiki.creativecommons.org/Considerations_for_licensees
57
+
58
+ =======================================================================
59
+
60
+ Creative Commons Attribution-NonCommercial 4.0 International Public
61
+ License
62
+
63
+ By exercising the Licensed Rights (defined below), You accept and agree
64
+ to be bound by the terms and conditions of this Creative Commons
65
+ Attribution-NonCommercial 4.0 International Public License ("Public
66
+ License"). To the extent this Public License may be interpreted as a
67
+ contract, You are granted the Licensed Rights in consideration of Your
68
+ acceptance of these terms and conditions, and the Licensor grants You
69
+ such rights in consideration of benefits the Licensor receives from
70
+ making the Licensed Material available under these terms and
71
+ conditions.
72
+
73
+
74
+ Section 1 -- Definitions.
75
+
76
+ a. Adapted Material means material subject to Copyright and Similar
77
+ Rights that is derived from or based upon the Licensed Material
78
+ and in which the Licensed Material is translated, altered,
79
+ arranged, transformed, or otherwise modified in a manner requiring
80
+ permission under the Copyright and Similar Rights held by the
81
+ Licensor. For purposes of this Public License, where the Licensed
82
+ Material is a musical work, performance, or sound recording,
83
+ Adapted Material is always produced where the Licensed Material is
84
+ synched in timed relation with a moving image.
85
+
86
+ b. Adapter's License means the license You apply to Your Copyright
87
+ and Similar Rights in Your contributions to Adapted Material in
88
+ accordance with the terms and conditions of this Public License.
89
+
90
+ c. Copyright and Similar Rights means copyright and/or similar rights
91
+ closely related to copyright including, without limitation,
92
+ performance, broadcast, sound recording, and Sui Generis Database
93
+ Rights, without regard to how the rights are labeled or
94
+ categorized. For purposes of this Public License, the rights
95
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
96
+ Rights.
97
+ d. Effective Technological Measures means those measures that, in the
98
+ absence of proper authority, may not be circumvented under laws
99
+ fulfilling obligations under Article 11 of the WIPO Copyright
100
+ Treaty adopted on December 20, 1996, and/or similar international
101
+ agreements.
102
+
103
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
104
+ any other exception or limitation to Copyright and Similar Rights
105
+ that applies to Your use of the Licensed Material.
106
+
107
+ f. Licensed Material means the artistic or literary work, database,
108
+ or other material to which the Licensor applied this Public
109
+ License.
110
+
111
+ g. Licensed Rights means the rights granted to You subject to the
112
+ terms and conditions of this Public License, which are limited to
113
+ all Copyright and Similar Rights that apply to Your use of the
114
+ Licensed Material and that the Licensor has authority to license.
115
+
116
+ h. Licensor means the individual(s) or entity(ies) granting rights
117
+ under this Public License.
118
+
119
+ i. NonCommercial means not primarily intended for or directed towards
120
+ commercial advantage or monetary compensation. For purposes of
121
+ this Public License, the exchange of the Licensed Material for
122
+ other material subject to Copyright and Similar Rights by digital
123
+ file-sharing or similar means is NonCommercial provided there is
124
+ no payment of monetary compensation in connection with the
125
+ exchange.
126
+
127
+ j. Share means to provide material to the public by any means or
128
+ process that requires permission under the Licensed Rights, such
129
+ as reproduction, public display, public performance, distribution,
130
+ dissemination, communication, or importation, and to make material
131
+ available to the public including in ways that members of the
132
+ public may access the material from a place and at a time
133
+ individually chosen by them.
134
+
135
+ k. Sui Generis Database Rights means rights other than copyright
136
+ resulting from Directive 96/9/EC of the European Parliament and of
137
+ the Council of 11 March 1996 on the legal protection of databases,
138
+ as amended and/or succeeded, as well as other essentially
139
+ equivalent rights anywhere in the world.
140
+
141
+ l. You means the individual or entity exercising the Licensed Rights
142
+ under this Public License. Your has a corresponding meaning.
143
+
144
+
145
+ Section 2 -- Scope.
146
+
147
+ a. License grant.
148
+
149
+ 1. Subject to the terms and conditions of this Public License,
150
+ the Licensor hereby grants You a worldwide, royalty-free,
151
+ non-sublicensable, non-exclusive, irrevocable license to
152
+ exercise the Licensed Rights in the Licensed Material to:
153
+
154
+ a. reproduce and Share the Licensed Material, in whole or
155
+ in part, for NonCommercial purposes only; and
156
+
157
+ b. produce, reproduce, and Share Adapted Material for
158
+ NonCommercial purposes only.
159
+
160
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
161
+ Exceptions and Limitations apply to Your use, this Public
162
+ License does not apply, and You do not need to comply with
163
+ its terms and conditions.
164
+
165
+ 3. Term. The term of this Public License is specified in Section
166
+ 6(a).
167
+
168
+ 4. Media and formats; technical modifications allowed. The
169
+ Licensor authorizes You to exercise the Licensed Rights in
170
+ all media and formats whether now known or hereafter created,
171
+ and to make technical modifications necessary to do so. The
172
+ Licensor waives and/or agrees not to assert any right or
173
+ authority to forbid You from making technical modifications
174
+ necessary to exercise the Licensed Rights, including
175
+ technical modifications necessary to circumvent Effective
176
+ Technological Measures. For purposes of this Public License,
177
+ simply making modifications authorized by this Section 2(a)
178
+ (4) never produces Adapted Material.
179
+
180
+ 5. Downstream recipients.
181
+
182
+ a. Offer from the Licensor -- Licensed Material. Every
183
+ recipient of the Licensed Material automatically
184
+ receives an offer from the Licensor to exercise the
185
+ Licensed Rights under the terms and conditions of this
186
+ Public License.
187
+
188
+ b. No downstream restrictions. You may not offer or impose
189
+ any additional or different terms or conditions on, or
190
+ apply any Effective Technological Measures to, the
191
+ Licensed Material if doing so restricts exercise of the
192
+ Licensed Rights by any recipient of the Licensed
193
+ Material.
194
+
195
+ 6. No endorsement. Nothing in this Public License constitutes or
196
+ may be construed as permission to assert or imply that You
197
+ are, or that Your use of the Licensed Material is, connected
198
+ with, or sponsored, endorsed, or granted official status by,
199
+ the Licensor or others designated to receive attribution as
200
+ provided in Section 3(a)(1)(A)(i).
201
+
202
+ b. Other rights.
203
+
204
+ 1. Moral rights, such as the right of integrity, are not
205
+ licensed under this Public License, nor are publicity,
206
+ privacy, and/or other similar personality rights; however, to
207
+ the extent possible, the Licensor waives and/or agrees not to
208
+ assert any such rights held by the Licensor to the limited
209
+ extent necessary to allow You to exercise the Licensed
210
+ Rights, but not otherwise.
211
+
212
+ 2. Patent and trademark rights are not licensed under this
213
+ Public License.
214
+
215
+ 3. To the extent possible, the Licensor waives any right to
216
+ collect royalties from You for the exercise of the Licensed
217
+ Rights, whether directly or through a collecting society
218
+ under any voluntary or waivable statutory or compulsory
219
+ licensing scheme. In all other cases the Licensor expressly
220
+ reserves any right to collect such royalties, including when
221
+ the Licensed Material is used other than for NonCommercial
222
+ purposes.
223
+
224
+
225
+ Section 3 -- License Conditions.
226
+
227
+ Your exercise of the Licensed Rights is expressly made subject to the
228
+ following conditions.
229
+
230
+ a. Attribution.
231
+
232
+ 1. If You Share the Licensed Material (including in modified
233
+ form), You must:
234
+
235
+ a. retain the following if it is supplied by the Licensor
236
+ with the Licensed Material:
237
+
238
+ i. identification of the creator(s) of the Licensed
239
+ Material and any others designated to receive
240
+ attribution, in any reasonable manner requested by
241
+ the Licensor (including by pseudonym if
242
+ designated);
243
+
244
+ ii. a copyright notice;
245
+
246
+ iii. a notice that refers to this Public License;
247
+
248
+ iv. a notice that refers to the disclaimer of
249
+ warranties;
250
+
251
+ v. a URI or hyperlink to the Licensed Material to the
252
+ extent reasonably practicable;
253
+
254
+ b. indicate if You modified the Licensed Material and
255
+ retain an indication of any previous modifications; and
256
+
257
+ c. indicate the Licensed Material is licensed under this
258
+ Public License, and include the text of, or the URI or
259
+ hyperlink to, this Public License.
260
+
261
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
262
+ reasonable manner based on the medium, means, and context in
263
+ which You Share the Licensed Material. For example, it may be
264
+ reasonable to satisfy the conditions by providing a URI or
265
+ hyperlink to a resource that includes the required
266
+ information.
267
+
268
+ 3. If requested by the Licensor, You must remove any of the
269
+ information required by Section 3(a)(1)(A) to the extent
270
+ reasonably practicable.
271
+
272
+ 4. If You Share Adapted Material You produce, the Adapter's
273
+ License You apply must not prevent recipients of the Adapted
274
+ Material from complying with this Public License.
275
+
276
+
277
+ Section 4 -- Sui Generis Database Rights.
278
+
279
+ Where the Licensed Rights include Sui Generis Database Rights that
280
+ apply to Your use of the Licensed Material:
281
+
282
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
283
+ to extract, reuse, reproduce, and Share all or a substantial
284
+ portion of the contents of the database for NonCommercial purposes
285
+ only;
286
+
287
+ b. if You include all or a substantial portion of the database
288
+ contents in a database in which You have Sui Generis Database
289
+ Rights, then the database in which You have Sui Generis Database
290
+ Rights (but not its individual contents) is Adapted Material; and
291
+
292
+ c. You must comply with the conditions in Section 3(a) if You Share
293
+ all or a substantial portion of the contents of the database.
294
+
295
+ For the avoidance of doubt, this Section 4 supplements and does not
296
+ replace Your obligations under this Public License where the Licensed
297
+ Rights include other Copyright and Similar Rights.
298
+
299
+
300
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
301
+
302
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
303
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
304
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
305
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
306
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
307
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
308
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
309
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
310
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
311
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
312
+
313
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
314
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
315
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
316
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
317
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
318
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
319
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
320
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
321
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
322
+
323
+ c. The disclaimer of warranties and limitation of liability provided
324
+ above shall be interpreted in a manner that, to the extent
325
+ possible, most closely approximates an absolute disclaimer and
326
+ waiver of all liability.
327
+
328
+
329
+ Section 6 -- Term and Termination.
330
+
331
+ a. This Public License applies for the term of the Copyright and
332
+ Similar Rights licensed here. However, if You fail to comply with
333
+ this Public License, then Your rights under this Public License
334
+ terminate automatically.
335
+
336
+ b. Where Your right to use the Licensed Material has terminated under
337
+ Section 6(a), it reinstates:
338
+
339
+ 1. automatically as of the date the violation is cured, provided
340
+ it is cured within 30 days of Your discovery of the
341
+ violation; or
342
+
343
+ 2. upon express reinstatement by the Licensor.
344
+
345
+ For the avoidance of doubt, this Section 6(b) does not affect any
346
+ right the Licensor may have to seek remedies for Your violations
347
+ of this Public License.
348
+
349
+ c. For the avoidance of doubt, the Licensor may also offer the
350
+ Licensed Material under separate terms or conditions or stop
351
+ distributing the Licensed Material at any time; however, doing so
352
+ will not terminate this Public License.
353
+
354
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
355
+ License.
356
+
357
+
358
+ Section 7 -- Other Terms and Conditions.
359
+
360
+ a. The Licensor shall not be bound by any additional or different
361
+ terms or conditions communicated by You unless expressly agreed.
362
+
363
+ b. Any arrangements, understandings, or agreements regarding the
364
+ Licensed Material not stated herein are separate from and
365
+ independent of the terms and conditions of this Public License.
366
+
367
+
368
+ Section 8 -- Interpretation.
369
+
370
+ a. For the avoidance of doubt, this Public License does not, and
371
+ shall not be interpreted to, reduce, limit, restrict, or impose
372
+ conditions on any use of the Licensed Material that could lawfully
373
+ be made without permission under this Public License.
374
+
375
+ b. To the extent possible, if any provision of this Public License is
376
+ deemed unenforceable, it shall be automatically reformed to the
377
+ minimum extent necessary to make it enforceable. If the provision
378
+ cannot be reformed, it shall be severed from this Public License
379
+ without affecting the enforceability of the remaining terms and
380
+ conditions.
381
+
382
+ c. No term or condition of this Public License will be waived and no
383
+ failure to comply consented to unless expressly agreed to by the
384
+ Licensor.
385
+
386
+ d. Nothing in this Public License constitutes or may be interpreted
387
+ as a limitation upon, or waiver of, any privileges and immunities
388
+ that apply to the Licensor or You, including from the legal
389
+ processes of any jurisdiction or authority.
390
+
391
+ =======================================================================
392
+
393
+ Creative Commons is not a party to its public
394
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
395
+ its public licenses to material it publishes and in those instances
396
+ will be considered the "Licensor." The text of the Creative Commons
397
+ public licenses is dedicated to the public domain under the CC0 Public
398
+ Domain Dedication. Except for the limited purpose of indicating that
399
+ material is shared under a Creative Commons public license or as
400
+ otherwise permitted by the Creative Commons policies published at
401
+ creativecommons.org/policies, Creative Commons does not authorize the
402
+ use of the trademark "Creative Commons" or any other trademark or logo
403
+ of Creative Commons without its prior written consent including,
404
+ without limitation, in connection with any unauthorized modifications
405
+ to any of its public licenses or any other arrangements,
406
+ understandings, or agreements concerning use of licensed material. For
407
+ the avoidance of doubt, this paragraph does not form part of the
408
+ public licenses.
409
+
410
+ Creative Commons may be contacted at creativecommons.org.
models/stylegan/stylegan_tf/README.md ADDED
@@ -0,0 +1,232 @@
1
+ ## StyleGAN &mdash; Official TensorFlow Implementation
2
+ ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg?style=plastic)
3
+ ![TensorFlow 1.10](https://img.shields.io/badge/tensorflow-1.10-green.svg?style=plastic)
4
+ ![cuDNN 7.3.1](https://img.shields.io/badge/cudnn-7.3.1-green.svg?style=plastic)
5
+ ![License CC BY-NC](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=plastic)
6
+
7
+ ![Teaser image](./stylegan-teaser.png)
8
+ **Picture:** *These people are not real &ndash; they were produced by our generator that allows control over different aspects of the image.*
9
+
10
+ This repository contains the official TensorFlow implementation of the following paper:
11
+
12
+ > **A Style-Based Generator Architecture for Generative Adversarial Networks**<br>
13
+ > Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)<br>
14
+ > https://arxiv.org/abs/1812.04948
15
+ >
16
+ > **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.*
17
+
18
+ For business inquiries, please contact [[email protected]](mailto:[email protected])<br>
19
+ For press and other inquiries, please contact Hector Marinez at [[email protected]](mailto:[email protected])<br>
20
+
21
+ **&#9733;&#9733;&#9733; NEW: StyleGAN2 is available at [https://github.com/NVlabs/stylegan2](https://github.com/NVlabs/stylegan2) &#9733;&#9733;&#9733;**
22
+
23
+ ## Resources
24
+
25
+ Material related to our paper is available via the following links:
26
+
27
+ - Paper: https://arxiv.org/abs/1812.04948
28
+ - Video: https://youtu.be/kSLJriaOumA
29
+ - Code: https://github.com/NVlabs/stylegan
30
+ - FFHQ: https://github.com/NVlabs/ffhq-dataset
31
+
32
+ Additional material can be found on Google Drive:
33
+
34
+ | Path | Description
35
+ | :--- | :----------
36
+ | [StyleGAN](https://drive.google.com/open?id=1uka3a1noXHAydRPRbknqwKVGODvnmUBX) | Main folder.
37
+ | &boxvr;&nbsp; [stylegan-paper.pdf](https://drive.google.com/open?id=1v-HkF3Ehrpon7wVIx4r5DLcko_U_V6Lt) | High-quality version of the paper PDF.
38
+ | &boxvr;&nbsp; [stylegan-video.mp4](https://drive.google.com/open?id=1uzwkZHQX_9pYg1i0d1Nbe3D9xPO8-qBf) | High-quality version of the result video.
39
+ | &boxvr;&nbsp; [images](https://drive.google.com/open?id=1-l46akONUWF6LCpDoeq63H53rD7MeiTd) | Example images produced using our generator.
40
+ | &boxv;&nbsp; &boxvr;&nbsp; [representative-images](https://drive.google.com/open?id=1ToY5P4Vvf5_c3TyUizQ8fckFFoFtBvD8) | High-quality images to be used in articles, blog posts, etc.
41
+ | &boxv;&nbsp; &boxur;&nbsp; [100k-generated-images](https://drive.google.com/open?id=100DJ0QXyG89HZzB4w2Cbyf4xjNK54cQ1) | 100,000 generated images for different amounts of truncation.
42
+ | &boxv;&nbsp; &ensp;&ensp; &boxvr;&nbsp; [ffhq-1024x1024](https://drive.google.com/open?id=14lm8VRN1pr4g_KVe6_LvyDX1PObst6d4) | Generated using Flickr-Faces-HQ dataset at 1024&times;1024.
43
+ | &boxv;&nbsp; &ensp;&ensp; &boxvr;&nbsp; [bedrooms-256x256](https://drive.google.com/open?id=1Vxz9fksw4kgjiHrvHkX4Hze4dyThFW6t) | Generated using LSUN Bedroom dataset at 256&times;256.
44
+ | &boxv;&nbsp; &ensp;&ensp; &boxvr;&nbsp; [cars-512x384](https://drive.google.com/open?id=1MFCvOMdLE2_mpeLPTiDw5dxc2CRuKkzS) | Generated using LSUN Car dataset at 512&times;384.
45
+ | &boxv;&nbsp; &ensp;&ensp; &boxur;&nbsp; [cats-256x256](https://drive.google.com/open?id=1gq-Gj3GRFiyghTPKhp8uDMA9HV_0ZFWQ) | Generated using LSUN Cat dataset at 256&times;256.
46
+ | &boxvr;&nbsp; [videos](https://drive.google.com/open?id=1N8pOd_Bf8v89NGUaROdbD8-ayLPgyRRo) | Example videos produced using our generator.
47
+ | &boxv;&nbsp; &boxur;&nbsp; [high-quality-video-clips](https://drive.google.com/open?id=1NFO7_vH0t98J13ckJYFd7kuaTkyeRJ86) | Individual segments of the result video as high-quality MP4.
48
+ | &boxvr;&nbsp; [ffhq-dataset](https://drive.google.com/open?id=1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP) | Raw data for the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset).
49
+ | &boxur;&nbsp; [networks](https://drive.google.com/open?id=1MASQyN5m0voPcx7-9K0r5gObhvvPups7) | Pre-trained networks as pickled instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py).
50
+ | &ensp;&ensp; &boxvr;&nbsp; [stylegan-ffhq-1024x1024.pkl](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ) | StyleGAN trained with Flickr-Faces-HQ dataset at 1024&times;1024.
51
+ | &ensp;&ensp; &boxvr;&nbsp; [stylegan-celebahq-1024x1024.pkl](https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf) | StyleGAN trained with CelebA-HQ dataset at 1024&times;1024.
52
+ | &ensp;&ensp; &boxvr;&nbsp; [stylegan-bedrooms-256x256.pkl](https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF) | StyleGAN trained with LSUN Bedroom dataset at 256&times;256.
53
+ | &ensp;&ensp; &boxvr;&nbsp; [stylegan-cars-512x384.pkl](https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3) | StyleGAN trained with LSUN Car dataset at 512&times;384.
54
+ | &ensp;&ensp; &boxvr;&nbsp; [stylegan-cats-256x256.pkl](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ) | StyleGAN trained with LSUN Cat dataset at 256&times;256.
55
+ | &ensp;&ensp; &boxur;&nbsp; [metrics](https://drive.google.com/open?id=1MvYdWCBuMfnoYGptRH-AgKLbPTsIQLhl) | Auxiliary networks for the quality and disentanglement metrics.
56
+ | &ensp;&ensp; &ensp;&ensp; &boxvr;&nbsp; [inception_v3_features.pkl](https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn) | Standard [Inception-v3](https://arxiv.org/abs/1512.00567) classifier that outputs a raw feature vector.
57
+ | &ensp;&ensp; &ensp;&ensp; &boxvr;&nbsp; [vgg16_zhang_perceptual.pkl](https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2) | Standard [LPIPS](https://arxiv.org/abs/1801.03924) metric to estimate perceptual similarity.
58
+ | &ensp;&ensp; &ensp;&ensp; &boxvr;&nbsp; [celebahq-classifier-00-male.pkl](https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX) | Binary classifier trained to detect a single attribute of CelebA-HQ.
59
+ | &ensp;&ensp; &ensp;&ensp; &boxur;&nbsp;&#x22ef; | Please see the file listing for remaining networks.
60
+
61
+ ## Licenses
62
+
63
+ All material, excluding the Flickr-Faces-HQ dataset, is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made.
64
+
65
+ For license information regarding the FFHQ dataset, please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).
66
+
67
+ `inception_v3_features.pkl` and `inception_v3_softmax.pkl` are derived from the pre-trained [Inception-v3](https://arxiv.org/abs/1512.00567) network by Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna. The network was originally shared under [Apache 2.0](https://github.com/tensorflow/models/blob/master/LICENSE) license on the [TensorFlow Models](https://github.com/tensorflow/models) repository.
68
+
69
+ `vgg16.pkl` and `vgg16_zhang_perceptual.pkl` are derived from the pre-trained [VGG-16](https://arxiv.org/abs/1409.1556) network by Karen Simonyan and Andrew Zisserman. The network was originally shared under [Creative Commons BY 4.0](https://creativecommons.org/licenses/by/4.0/) license on the [Very Deep Convolutional Networks for Large-Scale Visual Recognition](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) project page.
70
+
71
+ `vgg16_zhang_perceptual.pkl` is further derived from the pre-trained [LPIPS](https://arxiv.org/abs/1801.03924) weights by Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, and Oliver Wang. The weights were originally shared under [BSD 2-Clause "Simplified" License](https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE) on the [PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity) repository.
72
+
73
+ ## System requirements
74
+
75
+ * Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
76
+ * 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
77
+ * TensorFlow 1.10.0 or newer with GPU support.
78
+ * One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs.
79
+ * NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
80
+
81
+ ## Using pre-trained networks
82
+
83
+ A minimal example of using a pre-trained StyleGAN generator is given in [pretrained_example.py](./pretrained_example.py). When executed, the script downloads a pre-trained StyleGAN generator from Google Drive and uses it to generate an image:
84
+
85
+ ```
86
+ > python pretrained_example.py
87
+ Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done
88
+
89
+ Gs Params OutputShape WeightShape
90
+ --- --- --- ---
91
+ latents_in - (?, 512) -
92
+ ...
93
+ images_out - (?, 3, 1024, 1024) -
94
+ --- --- --- ---
95
+ Total 26219627
96
+
97
+ > ls results
98
+ example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
99
+ ```
100
+
101
+ A more advanced example is given in [generate_figures.py](./generate_figures.py). The script reproduces the figures from our paper in order to illustrate style mixing, noise inputs, and truncation:
102
+ ```
103
+ > python generate_figures.py
104
+ results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu
105
+ results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6
106
+ results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG
107
+ results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_
108
+ results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v
109
+ results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr
110
+ results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke
111
+ results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
112
+ ```
113
+
114
+ The pre-trained networks are stored as standard pickle files on Google Drive:
115
+
116
+ ```
117
+ # Load pre-trained network.
118
+ url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
119
+ with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
120
+ _G, _D, Gs = pickle.load(f)
121
+ # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
122
+ # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
123
+ # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
124
+ ```
125
+
126
+ The above code downloads the file and unpickles it to yield 3 instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py). To generate images, you will typically want to use `Gs` &ndash; the other two networks are provided for completeness. In order for `pickle.load()` to work, you will need to have the `dnnlib` source directory in your PYTHONPATH and a `tf.Session` set as default. The session can be initialized by calling `dnnlib.tflib.init_tf()`.
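+
+ A minimal end-to-end sketch that satisfies these requirements (it mirrors [pretrained_example.py](./pretrained_example.py) and assumes the repository root is the current working directory so that `dnnlib` and `config` are importable):
+
+ ```
+ import pickle
+ import dnnlib
+ import dnnlib.tflib as tflib
+ import config
+
+ tflib.init_tf()  # create a default tf.Session before unpickling
+ url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
+ with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
+     _G, _D, Gs = pickle.load(f)  # Gs = long-term average of the generator
+ ```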
127
+
128
+ There are three ways to use the pre-trained generator:
129
+
130
+ 1. Use `Gs.run()` for immediate-mode operation where the inputs and outputs are numpy arrays:
131
+ ```
132
+ # Pick latent vector.
133
+ rnd = np.random.RandomState(5)
134
+ latents = rnd.randn(1, Gs.input_shape[1])
135
+
136
+ # Generate image.
137
+ fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
138
+ images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
139
+ ```
140
+ The first argument is a batch of latent vectors of shape `[num, 512]`. The second argument is reserved for class labels (not used by StyleGAN). The remaining keyword arguments are optional and can be used to further modify the operation (see below). The output is a batch of images, whose format is dictated by the `output_transform` argument. A minimal sketch of saving the returned images to disk is given after this list.
141
+
142
+ 2. Use `Gs.get_output_for()` to incorporate the generator as a part of a larger TensorFlow expression:
143
+ ```
144
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
145
+ images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
146
+ images = tflib.convert_images_to_uint8(images)
147
+ result_expr.append(inception_clone.get_output_for(images))
148
+ ```
149
+ The above code is from [metrics/frechet_inception_distance.py](./metrics/frechet_inception_distance.py). It generates a batch of random images and feeds them directly to the [Inception-v3](https://arxiv.org/abs/1512.00567) network without having to convert the data to numpy arrays in between.
150
+
151
+ 3. Look up `Gs.components.mapping` and `Gs.components.synthesis` to access individual sub-networks of the generator. Similar to `Gs`, the sub-networks are represented as independent instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py):
152
+ ```
153
+ src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)
154
+ src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
155
+ src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
156
+ ```
157
+ The above code is from [generate_figures.py](./generate_figures.py). It first transforms a batch of latent vectors into the intermediate *W* space using the mapping network and then turns these vectors into a batch of images using the synthesis network. The `dlatents` array stores a separate copy of the same *w* vector for each layer of the synthesis network to facilitate style mixing.
158
+
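+ To save the images returned by `Gs.run()` in the first approach above, the following minimal sketch mirrors [pretrained_example.py](./pretrained_example.py) (`images` is the array produced with the `output_transform=fmt` shown earlier, i.e. uint8 in NHWC layout):
+
+ ```
+ import os
+ import PIL.Image
+ import config
+
+ os.makedirs(config.result_dir, exist_ok=True)
+ png_filename = os.path.join(config.result_dir, 'example.png')
+ PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
+ ```
+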
159
+ The exact details of the generator are defined in [training/networks_stylegan.py](./training/networks_stylegan.py) (see `G_style`, `G_mapping`, and `G_synthesis`). The following keyword arguments can be specified to modify the behavior when calling `run()` and `get_output_for()`:
160
+
161
+ * `truncation_psi` and `truncation_cutoff` control the truncation trick that is performed by default when using `Gs` (&psi;=0.7, cutoff=8). It can be disabled by setting `truncation_psi=1` or `is_validation=True`, and the image quality can be further improved at the cost of variation by setting e.g. `truncation_psi=0.5`. Note that truncation is always disabled when using the sub-networks directly. The average *w* needed to manually perform the truncation trick can be looked up using `Gs.get_var('dlatent_avg')`. A sketch of applying the trick manually is given after this list.
162
+
163
+ * `randomize_noise` determines whether to re-randomize the noise inputs for each generated image (`True`, default) or whether to use specific noise values for the entire minibatch (`False`). The specific values can be accessed via the `tf.Variable` instances that are found using `[var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]`. Setting them is also shown in the sketch after this list.
164
+
165
+ * When using the mapping network directly, you can specify `dlatent_broadcast=None` to disable the automatic duplication of `dlatents` over the layers of the synthesis network.
166
+
167
+ * Runtime performance can be fine-tuned via `structure='fixed'` and `dtype='float16'`. The former disables support for progressive growing, which is not needed for a fully-trained generator, and the latter performs all computation using half-precision floating point arithmetic.
168
+
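+ Two of the options above can also be reproduced manually when working with the sub-networks directly. The following is an illustrative sketch, not code taken verbatim from the scripts: it applies the truncation trick by hand using `Gs.get_var('dlatent_avg')` (ignoring the per-layer cutoff for simplicity) and pins the noise inputs to fixed values via `tflib.set_vars()`, the variable-assignment helper in `dnnlib.tflib`. It reuses `src_latents` from the example under item 3 above.
+
+ ```
+ # Manual truncation trick: move the dlatents towards the average w (cutoff ignored for simplicity).
+ w_avg = Gs.get_var('dlatent_avg')                               # [component]
+ src_dlatents = Gs.components.mapping.run(src_latents, None)     # [seed, layer, component]
+ truncation_psi = 0.7
+ src_dlatents = w_avg + (src_dlatents - w_avg) * truncation_psi  # broadcasts over seeds and layers
+
+ # Pin the noise inputs to fixed values for the entire minibatch.
+ noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
+ tflib.set_vars({var: np.random.RandomState(0).randn(*var.shape.as_list()) for var in noise_vars})
+
+ src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False)
+ ```
+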
169
+ ## Preparing datasets for training
170
+
171
+ The training and evaluation scripts operate on datasets stored as multi-resolution TFRecords. Each dataset is represented by a directory containing the same image data in several resolutions to enable efficient streaming. There is a separate *.tfrecords file for each resolution, and if the dataset contains labels, they are stored in a separate file as well. By default, the scripts expect to find the datasets at `datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords`. The directory can be changed by editing [config.py](./config.py):
172
+
173
+ ```
174
+ result_dir = 'results'
175
+ data_dir = 'datasets'
176
+ cache_dir = 'cache'
177
+ ```
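+
+ For example, after converting a 1024&times;1024 dataset called `ffhq`, the directory is expected to look roughly as follows. This is a sketch based on the file naming used by `TFRecordExporter` in [dataset_tool.py](./dataset_tool.py), where the `-rNN` suffix is the log2 of the resolution; the labels file is present only for labeled datasets:
+
+ ```
+ datasets/ffhq/ffhq-r02.tfrecords   # 4x4
+ datasets/ffhq/ffhq-r03.tfrecords   # 8x8
+ ...
+ datasets/ffhq/ffhq-r10.tfrecords   # 1024x1024
+ datasets/ffhq/ffhq-rxx.labels      # optional, only if the dataset has labels
+ ```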
178
+
179
+ To obtain the FFHQ dataset (`datasets/ffhq`), please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).
180
+
181
+ To obtain the CelebA-HQ dataset (`datasets/celebahq`), please refer to the [Progressive GAN repository](https://github.com/tkarras/progressive_growing_of_gans).
182
+
183
+ To obtain other datasets, including LSUN, please consult their corresponding project pages. The datasets can be converted to multi-resolution TFRecords using the provided [dataset_tool.py](./dataset_tool.py):
184
+
185
+ ```
186
+ > python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256
187
+ > python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384
188
+ > python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256
189
+ > python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10
190
+ > python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
191
+ ```
192
+
193
+ ## Training networks
194
+
195
+ Once the datasets are set up, you can train your own StyleGAN networks as follows:
196
+
197
+ 1. Edit [train.py](./train.py) to specify the dataset and training configuration by uncommenting or editing specific lines.
198
+ 2. Run the training script with `python train.py`.
199
+ 3. The results are written to a newly created directory `results/<ID>-<DESCRIPTION>`.
200
+ 4. The training may take several days (or weeks) to complete, depending on the configuration.
201
+
202
+ By default, `train.py` is configured to train the highest-quality StyleGAN (configuration F in Table 1) for the FFHQ dataset at 1024&times;1024 resolution using 8 GPUs. Please note that we have used 8 GPUs in all of our experiments. Training with fewer GPUs may not produce identical results &ndash; if you wish to compare against our technique, we strongly recommend using the same number of GPUs.
203
+
204
+ Expected training times for the default configuration using Tesla V100 GPUs:
205
+
206
+ | GPUs | 1024&times;1024 | 512&times;512 | 256&times;256 |
207
+ | :--- | :-------------- | :------------ | :------------ |
208
+ | 1 | 41 days 4 hours | 24 days 21 hours | 14 days 22 hours |
209
+ | 2 | 21 days 22 hours | 13 days 7 hours | 9 days 5 hours |
210
+ | 4 | 11 days 8 hours | 7 days 0 hours | 4 days 21 hours |
211
+ | 8 | 6 days 14 hours | 4 days 10 hours | 3 days 8 hours |
212
+
213
+ ## Evaluating quality and disentanglement
214
+
215
+ The quality and disentanglement metrics used in our paper can be evaluated using [run_metrics.py](./run_metrics.py). By default, the script will evaluate the Fr&eacute;chet Inception Distance (`fid50k`) for the pre-trained FFHQ generator and write the results into a newly created directory under `results`. The exact behavior can be changed by uncommenting or editing specific lines in [run_metrics.py](./run_metrics.py).
216
+
217
+ Expected evaluation time and results for the pre-trained FFHQ generator using one Tesla V100 GPU:
218
+
219
+ | Metric | Time | Result | Description
220
+ | :----- | :--- | :----- | :----------
221
+ | fid50k | 16 min | 4.4159 | Fr&eacute;chet Inception Distance using 50,000 images.
222
+ | ppl_zfull | 55 min | 664.8854 | Perceptual Path Length for full paths in *Z*.
223
+ | ppl_wfull | 55 min | 233.3059 | Perceptual Path Length for full paths in *W*.
224
+ | ppl_zend | 55 min | 666.1057 | Perceptual Path Length for path endpoints in *Z*.
225
+ | ppl_wend | 55 min | 197.2266 | Perceptual Path Length for path endpoints in *W*.
226
+ | ls | 10 hours | z: 165.0106<br>w: 3.7447 | Linear Separability in *Z* and *W*.
227
+
228
+ Please note that the exact results may vary from run to run due to the non-deterministic nature of TensorFlow.
229
+
230
+ ## Acknowledgements
231
+
232
+ We thank Jaakko Lehtinen, David Luebke, and Tuomas Kynk&auml;&auml;nniemi for in-depth discussions and helpful comments; Janne Hellsten, Tero Kuosmanen, and Pekka J&auml;nis for compute infrastructure and help with the code release.
models/stylegan/stylegan_tf/config.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Global configuration."""
9
+
10
+ #----------------------------------------------------------------------------
11
+ # Paths.
12
+
13
+ result_dir = 'results'
14
+ data_dir = 'datasets'
15
+ cache_dir = 'cache'
16
+ run_dir_ignore = ['results', 'datasets', 'cache']
17
+
18
+ #----------------------------------------------------------------------------
models/stylegan/stylegan_tf/dataset_tool.py ADDED
@@ -0,0 +1,645 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN."""
9
+
10
+ # pylint: disable=too-many-lines
11
+ import os
12
+ import sys
13
+ import glob
14
+ import argparse
15
+ import threading
16
+ import six.moves.queue as Queue # pylint: disable=import-error
17
+ import traceback
18
+ import numpy as np
19
+ import tensorflow as tf
20
+ import PIL.Image
21
+ import dnnlib.tflib as tflib
22
+
23
+ from training import dataset
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ def error(msg):
28
+ print('Error: ' + msg)
29
+ exit(1)
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ class TFRecordExporter:
34
+ def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
35
+ self.tfrecord_dir = tfrecord_dir
36
+ self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
37
+ self.expected_images = expected_images
38
+ self.cur_images = 0
39
+ self.shape = None
40
+ self.resolution_log2 = None
41
+ self.tfr_writers = []
42
+ self.print_progress = print_progress
43
+ self.progress_interval = progress_interval
44
+
45
+ if self.print_progress:
46
+ print('Creating dataset "%s"' % tfrecord_dir)
47
+ if not os.path.isdir(self.tfrecord_dir):
48
+ os.makedirs(self.tfrecord_dir)
49
+ assert os.path.isdir(self.tfrecord_dir)
50
+
51
+ def close(self):
52
+ if self.print_progress:
53
+ print('%-40s\r' % 'Flushing data...', end='', flush=True)
54
+ for tfr_writer in self.tfr_writers:
55
+ tfr_writer.close()
56
+ self.tfr_writers = []
57
+ if self.print_progress:
58
+ print('%-40s\r' % '', end='', flush=True)
59
+ print('Added %d images.' % self.cur_images)
60
+
61
+ def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
62
+ order = np.arange(self.expected_images)
63
+ np.random.RandomState(123).shuffle(order)
64
+ return order
65
+
66
+ def add_image(self, img):
67
+ if self.print_progress and self.cur_images % self.progress_interval == 0:
68
+ print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
69
+ if self.shape is None:
70
+ self.shape = img.shape
71
+ self.resolution_log2 = int(np.log2(self.shape[1]))
72
+ assert self.shape[0] in [1, 3]
73
+ assert self.shape[1] == self.shape[2]
74
+ assert self.shape[1] == 2**self.resolution_log2
75
+ tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
76
+ for lod in range(self.resolution_log2 - 1):
77
+ tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
78
+ self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
79
+ assert img.shape == self.shape
80
+ for lod, tfr_writer in enumerate(self.tfr_writers):
81
+ if lod:
82
+ img = img.astype(np.float32)
83
+ img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
84
+ quant = np.rint(img).clip(0, 255).astype(np.uint8)
85
+ ex = tf.train.Example(features=tf.train.Features(feature={
86
+ 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
87
+ 'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))}))
88
+ tfr_writer.write(ex.SerializeToString())
89
+ self.cur_images += 1
90
+
91
+ def add_labels(self, labels):
92
+ if self.print_progress:
93
+ print('%-40s\r' % 'Saving labels...', end='', flush=True)
94
+ assert labels.shape[0] == self.cur_images
95
+ with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
96
+ np.save(f, labels.astype(np.float32))
97
+
98
+ def __enter__(self):
99
+ return self
100
+
101
+ def __exit__(self, *args):
102
+ self.close()
103
+
104
+ #----------------------------------------------------------------------------
105
+
106
+ class ExceptionInfo(object):
107
+ def __init__(self):
108
+ self.value = sys.exc_info()[1]
109
+ self.traceback = traceback.format_exc()
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ class WorkerThread(threading.Thread):
114
+ def __init__(self, task_queue):
115
+ threading.Thread.__init__(self)
116
+ self.task_queue = task_queue
117
+
118
+ def run(self):
119
+ while True:
120
+ func, args, result_queue = self.task_queue.get()
121
+ if func is None:
122
+ break
123
+ try:
124
+ result = func(*args)
125
+ except:
126
+ result = ExceptionInfo()
127
+ result_queue.put((result, args))
128
+
129
+ #----------------------------------------------------------------------------
130
+
131
+ class ThreadPool(object):
132
+ def __init__(self, num_threads):
133
+ assert num_threads >= 1
134
+ self.task_queue = Queue.Queue()
135
+ self.result_queues = dict()
136
+ self.num_threads = num_threads
137
+ for _idx in range(self.num_threads):
138
+ thread = WorkerThread(self.task_queue)
139
+ thread.daemon = True
140
+ thread.start()
141
+
142
+ def add_task(self, func, args=()):
143
+ assert hasattr(func, '__call__') # must be a function
144
+ if func not in self.result_queues:
145
+ self.result_queues[func] = Queue.Queue()
146
+ self.task_queue.put((func, args, self.result_queues[func]))
147
+
148
+ def get_result(self, func): # returns (result, args)
149
+ result, args = self.result_queues[func].get()
150
+ if isinstance(result, ExceptionInfo):
151
+ print('\n\nWorker thread caught an exception:\n' + result.traceback)
152
+ raise result.value
153
+ return result, args
154
+
155
+ def finish(self):
156
+ for _idx in range(self.num_threads):
157
+ self.task_queue.put((None, (), None))
158
+
159
+ def __enter__(self): # for 'with' statement
160
+ return self
161
+
162
+ def __exit__(self, *excinfo):
163
+ self.finish()
164
+
165
+ def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
166
+ if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4
167
+ assert max_items_in_flight >= 1
168
+ results = []
169
+ retire_idx = [0]
170
+
171
+ def task_func(prepared, _idx):
172
+ return process_func(prepared)
173
+
174
+ def retire_result():
175
+ processed, (_prepared, idx) = self.get_result(task_func)
176
+ results[idx] = processed
177
+ while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
178
+ yield post_func(results[retire_idx[0]])
179
+ results[retire_idx[0]] = None
180
+ retire_idx[0] += 1
181
+
182
+ for idx, item in enumerate(item_iterator):
183
+ prepared = pre_func(item)
184
+ results.append(None)
185
+ self.add_task(func=task_func, args=(prepared, idx))
186
+ while retire_idx[0] < idx - max_items_in_flight + 2:
187
+ for res in retire_result(): yield res
188
+ while retire_idx[0] < len(results):
189
+ for res in retire_result(): yield res
190
+
191
+ #----------------------------------------------------------------------------
192
+
193
+ def display(tfrecord_dir):
194
+ print('Loading dataset "%s"' % tfrecord_dir)
195
+ tflib.init_tf({'gpu_options.allow_growth': True})
196
+ dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
197
+ tflib.init_uninitialized_vars()
198
+ import cv2 # pip install opencv-python
199
+
200
+ idx = 0
201
+ while True:
202
+ try:
203
+ images, labels = dset.get_minibatch_np(1)
204
+ except tf.errors.OutOfRangeError:
205
+ break
206
+ if idx == 0:
207
+ print('Displaying images')
208
+ cv2.namedWindow('dataset_tool')
209
+ print('Press SPACE or ENTER to advance, ESC to exit')
210
+ print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist()))
211
+ cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR
212
+ idx += 1
213
+ if cv2.waitKey() == 27:
214
+ break
215
+ print('\nDisplayed %d images.' % idx)
216
+
217
+ #----------------------------------------------------------------------------
218
+
219
+ def extract(tfrecord_dir, output_dir):
220
+ print('Loading dataset "%s"' % tfrecord_dir)
221
+ tflib.init_tf({'gpu_options.allow_growth': True})
222
+ dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
223
+ tflib.init_uninitialized_vars()
224
+
225
+ print('Extracting images to "%s"' % output_dir)
226
+ if not os.path.isdir(output_dir):
227
+ os.makedirs(output_dir)
228
+ idx = 0
229
+ while True:
230
+ if idx % 10 == 0:
231
+ print('%d\r' % idx, end='', flush=True)
232
+ try:
233
+ images, _labels = dset.get_minibatch_np(1)
234
+ except tf.errors.OutOfRangeError:
235
+ break
236
+ if images.shape[1] == 1:
237
+ img = PIL.Image.fromarray(images[0][0], 'L')
238
+ else:
239
+ img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB')
240
+ img.save(os.path.join(output_dir, 'img%08d.png' % idx))
241
+ idx += 1
242
+ print('Extracted %d images.' % idx)
243
+
244
+ #----------------------------------------------------------------------------
245
+
246
+ def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
247
+ max_label_size = 0 if ignore_labels else 'full'
248
+ print('Loading dataset "%s"' % tfrecord_dir_a)
249
+ tflib.init_tf({'gpu_options.allow_growth': True})
250
+ dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
251
+ print('Loading dataset "%s"' % tfrecord_dir_b)
252
+ dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
253
+ tflib.init_uninitialized_vars()
254
+
255
+ print('Comparing datasets')
256
+ idx = 0
257
+ identical_images = 0
258
+ identical_labels = 0
259
+ while True:
260
+ if idx % 100 == 0:
261
+ print('%d\r' % idx, end='', flush=True)
262
+ try:
263
+ images_a, labels_a = dset_a.get_minibatch_np(1)
264
+ except tf.errors.OutOfRangeError:
265
+ images_a, labels_a = None, None
266
+ try:
267
+ images_b, labels_b = dset_b.get_minibatch_np(1)
268
+ except tf.errors.OutOfRangeError:
269
+ images_b, labels_b = None, None
270
+ if images_a is None or images_b is None:
271
+ if images_a is not None or images_b is not None:
272
+ print('Datasets contain different number of images')
273
+ break
274
+ if images_a.shape == images_b.shape and np.all(images_a == images_b):
275
+ identical_images += 1
276
+ else:
277
+ print('Image %d is different' % idx)
278
+ if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
279
+ identical_labels += 1
280
+ else:
281
+ print('Label %d is different' % idx)
282
+ idx += 1
283
+ print('Identical images: %d / %d' % (identical_images, idx))
284
+ if not ignore_labels:
285
+ print('Identical labels: %d / %d' % (identical_labels, idx))
286
+
287
+ #----------------------------------------------------------------------------
288
+
289
+ def create_mnist(tfrecord_dir, mnist_dir):
290
+ print('Loading MNIST from "%s"' % mnist_dir)
291
+ import gzip
292
+ with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
293
+ images = np.frombuffer(file.read(), np.uint8, offset=16)
294
+ with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:
295
+ labels = np.frombuffer(file.read(), np.uint8, offset=8)
296
+ images = images.reshape(-1, 1, 28, 28)
297
+ images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)
298
+ assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
299
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
300
+ assert np.min(images) == 0 and np.max(images) == 255
301
+ assert np.min(labels) == 0 and np.max(labels) == 9
302
+ onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
303
+ onehot[np.arange(labels.size), labels] = 1.0
304
+
305
+ with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
306
+ order = tfr.choose_shuffled_order()
307
+ for idx in range(order.size):
308
+ tfr.add_image(images[order[idx]])
309
+ tfr.add_labels(onehot[order])
310
+
311
+ #----------------------------------------------------------------------------
312
+
313
+ def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
314
+ print('Loading MNIST from "%s"' % mnist_dir)
315
+ import gzip
316
+ with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
317
+ images = np.frombuffer(file.read(), np.uint8, offset=16)
318
+ images = images.reshape(-1, 28, 28)
319
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
320
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
321
+ assert np.min(images) == 0 and np.max(images) == 255
322
+
323
+ with TFRecordExporter(tfrecord_dir, num_images) as tfr:
324
+ rnd = np.random.RandomState(random_seed)
325
+ for _idx in range(num_images):
326
+ tfr.add_image(images[rnd.randint(images.shape[0], size=3)])
327
+
328
+ #----------------------------------------------------------------------------
329
+
330
+ def create_cifar10(tfrecord_dir, cifar10_dir):
331
+ print('Loading CIFAR-10 from "%s"' % cifar10_dir)
332
+ import pickle
333
+ images = []
334
+ labels = []
335
+ for batch in range(1, 6):
336
+ with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
337
+ data = pickle.load(file, encoding='latin1')
338
+ images.append(data['data'].reshape(-1, 3, 32, 32))
339
+ labels.append(data['labels'])
340
+ images = np.concatenate(images)
341
+ labels = np.concatenate(labels)
342
+ assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
343
+ assert labels.shape == (50000,) and labels.dtype == np.int32
344
+ assert np.min(images) == 0 and np.max(images) == 255
345
+ assert np.min(labels) == 0 and np.max(labels) == 9
346
+ onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
347
+ onehot[np.arange(labels.size), labels] = 1.0
348
+
349
+ with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
350
+ order = tfr.choose_shuffled_order()
351
+ for idx in range(order.size):
352
+ tfr.add_image(images[order[idx]])
353
+ tfr.add_labels(onehot[order])
354
+
355
+ #----------------------------------------------------------------------------
356
+
357
+ def create_cifar100(tfrecord_dir, cifar100_dir):
358
+ print('Loading CIFAR-100 from "%s"' % cifar100_dir)
359
+ import pickle
360
+ with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:
361
+ data = pickle.load(file, encoding='latin1')
362
+ images = data['data'].reshape(-1, 3, 32, 32)
363
+ labels = np.array(data['fine_labels'])
364
+ assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
365
+ assert labels.shape == (50000,) and labels.dtype == np.int32
366
+ assert np.min(images) == 0 and np.max(images) == 255
367
+ assert np.min(labels) == 0 and np.max(labels) == 99
368
+ onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
369
+ onehot[np.arange(labels.size), labels] = 1.0
370
+
371
+ with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
372
+ order = tfr.choose_shuffled_order()
373
+ for idx in range(order.size):
374
+ tfr.add_image(images[order[idx]])
375
+ tfr.add_labels(onehot[order])
376
+
377
+ #----------------------------------------------------------------------------
378
+
379
+ def create_svhn(tfrecord_dir, svhn_dir):
380
+ print('Loading SVHN from "%s"' % svhn_dir)
381
+ import pickle
382
+ images = []
383
+ labels = []
384
+ for batch in range(1, 4):
385
+ with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file:
386
+ data = pickle.load(file, encoding='latin1')
387
+ images.append(data[0])
388
+ labels.append(data[1])
389
+ images = np.concatenate(images)
390
+ labels = np.concatenate(labels)
391
+ assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8
392
+ assert labels.shape == (73257,) and labels.dtype == np.uint8
393
+ assert np.min(images) == 0 and np.max(images) == 255
394
+ assert np.min(labels) == 0 and np.max(labels) == 9
395
+ onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
396
+ onehot[np.arange(labels.size), labels] = 1.0
397
+
398
+ with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
399
+ order = tfr.choose_shuffled_order()
400
+ for idx in range(order.size):
401
+ tfr.add_image(images[order[idx]])
402
+ tfr.add_labels(onehot[order])
403
+
404
+ #----------------------------------------------------------------------------
405
+
406
+ def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
407
+ print('Loading LSUN dataset from "%s"' % lmdb_dir)
408
+ import lmdb # pip install lmdb # pylint: disable=import-error
409
+ import cv2 # pip install opencv-python
410
+ import io
411
+ with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
412
+ total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
413
+ if max_images is None:
414
+ max_images = total_images
415
+ with TFRecordExporter(tfrecord_dir, max_images) as tfr:
416
+ for _idx, (_key, value) in enumerate(txn.cursor()):
417
+ try:
418
+ try:
419
+ img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
420
+ if img is None:
421
+ raise IOError('cv2.imdecode failed')
422
+ img = img[:, :, ::-1] # BGR => RGB
423
+ except IOError:
424
+ img = np.asarray(PIL.Image.open(io.BytesIO(value)))
425
+ crop = np.min(img.shape[:2])
426
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
427
+ img = PIL.Image.fromarray(img, 'RGB')
428
+ img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)
429
+ img = np.asarray(img)
430
+ img = img.transpose([2, 0, 1]) # HWC => CHW
431
+ tfr.add_image(img)
432
+ except:
433
+ print(sys.exc_info()[1])
434
+ if tfr.cur_images == max_images:
435
+ break
436
+
437
+ #----------------------------------------------------------------------------
438
+
439
+ def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None):
440
+ assert width == 2 ** int(np.round(np.log2(width)))
441
+ assert height <= width
442
+ print('Loading LSUN dataset from "%s"' % lmdb_dir)
443
+ import lmdb # pip install lmdb # pylint: disable=import-error
444
+ import cv2 # pip install opencv-python
445
+ import io
446
+ with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
447
+ total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
448
+ if max_images is None:
449
+ max_images = total_images
450
+ with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr:
451
+ for idx, (_key, value) in enumerate(txn.cursor()):
452
+ try:
453
+ try:
454
+ img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
455
+ if img is None:
456
+ raise IOError('cv2.imdecode failed')
457
+ img = img[:, :, ::-1] # BGR => RGB
458
+ except IOError:
459
+ img = np.asarray(PIL.Image.open(io.BytesIO(value)))
460
+
461
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
462
+ if img.shape[1] < width or ch < height:
463
+ continue
464
+
465
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
466
+ img = PIL.Image.fromarray(img, 'RGB')
467
+ img = img.resize((width, height), PIL.Image.ANTIALIAS)
468
+ img = np.asarray(img)
469
+ img = img.transpose([2, 0, 1]) # HWC => CHW
470
+
471
+ canvas = np.zeros([3, width, width], dtype=np.uint8)
472
+ canvas[:, (width - height) // 2 : (width + height) // 2] = img
473
+ tfr.add_image(canvas)
474
+ print('\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='')
475
+
476
+ except:
477
+ print(sys.exc_info()[1])
478
+ if tfr.cur_images == max_images:
479
+ break
480
+ print()
481
+
482
+ #----------------------------------------------------------------------------
483
+
484
+ def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
485
+ print('Loading CelebA from "%s"' % celeba_dir)
486
+ glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
487
+ image_filenames = sorted(glob.glob(glob_pattern))
488
+ expected_images = 202599
489
+ if len(image_filenames) != expected_images:
490
+ error('Expected to find %d images' % expected_images)
491
+
492
+ with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
493
+ order = tfr.choose_shuffled_order()
494
+ for idx in range(order.size):
495
+ img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
496
+ assert img.shape == (218, 178, 3)
497
+ img = img[cy - 64 : cy + 64, cx - 64 : cx + 64]
498
+ img = img.transpose(2, 0, 1) # HWC => CHW
499
+ tfr.add_image(img)
500
+
501
+ #----------------------------------------------------------------------------
502
+
503
+ def create_from_images(tfrecord_dir, image_dir, shuffle):
504
+ print('Loading images from "%s"' % image_dir)
505
+ image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
506
+ if len(image_filenames) == 0:
507
+ error('No input images found')
508
+
509
+ img = np.asarray(PIL.Image.open(image_filenames[0]))
510
+ resolution = img.shape[0]
511
+ channels = img.shape[2] if img.ndim == 3 else 1
512
+ if img.shape[1] != resolution:
513
+ error('Input images must have the same width and height')
514
+ if resolution != 2 ** int(np.floor(np.log2(resolution))):
515
+ error('Input image resolution must be a power-of-two')
516
+ if channels not in [1, 3]:
517
+ error('Input images must be stored as RGB or grayscale')
518
+
519
+ with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
520
+ order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
521
+ for idx in range(order.size):
522
+ img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
523
+ if channels == 1:
524
+ img = img[np.newaxis, :, :] # HW => CHW
525
+ else:
526
+ img = img.transpose([2, 0, 1]) # HWC => CHW
527
+ tfr.add_image(img)
528
+
529
+ #----------------------------------------------------------------------------
530
+
531
+ def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
532
+ print('Loading HDF5 archive from "%s"' % hdf5_filename)
533
+ import h5py # conda install h5py
534
+ with h5py.File(hdf5_filename, 'r') as hdf5_file:
535
+ hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3])
536
+ with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr:
537
+ order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0])
538
+ for idx in range(order.size):
539
+ tfr.add_image(hdf5_data[order[idx]])
540
+ npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy'
541
+ if os.path.isfile(npy_filename):
542
+ tfr.add_labels(np.load(npy_filename)[order])
543
+
544
+ #----------------------------------------------------------------------------
545
+
546
+ def execute_cmdline(argv):
547
+ prog = argv[0]
548
+ parser = argparse.ArgumentParser(
549
+ prog = prog,
550
+ description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.',
551
+ epilog = 'Type "%s <command> -h" for more information.' % prog)
552
+
553
+ subparsers = parser.add_subparsers(dest='command')
554
+ subparsers.required = True
555
+ def add_command(cmd, desc, example=None):
556
+ epilog = 'Example: %s %s' % (prog, example) if example is not None else None
557
+ return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)
558
+
559
+ p = add_command( 'display', 'Display images in dataset.',
560
+ 'display datasets/mnist')
561
+ p.add_argument( 'tfrecord_dir', help='Directory containing dataset')
562
+
563
+ p = add_command( 'extract', 'Extract images from dataset.',
564
+ 'extract datasets/mnist mnist-images')
565
+ p.add_argument( 'tfrecord_dir', help='Directory containing dataset')
566
+ p.add_argument( 'output_dir', help='Directory to extract the images into')
567
+
568
+ p = add_command( 'compare', 'Compare two datasets.',
569
+ 'compare datasets/mydataset datasets/mnist')
570
+ p.add_argument( 'tfrecord_dir_a', help='Directory containing first dataset')
571
+ p.add_argument( 'tfrecord_dir_b', help='Directory containing second dataset')
572
+ p.add_argument( '--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0)
573
+
574
+ p = add_command( 'create_mnist', 'Create dataset for MNIST.',
575
+ 'create_mnist datasets/mnist ~/downloads/mnist')
576
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
577
+ p.add_argument( 'mnist_dir', help='Directory containing MNIST')
578
+
579
+ p = add_command( 'create_mnistrgb', 'Create dataset for MNIST-RGB.',
580
+ 'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')
581
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
582
+ p.add_argument( 'mnist_dir', help='Directory containing MNIST')
583
+ p.add_argument( '--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
584
+ p.add_argument( '--random_seed', help='Random seed (default: 123)', type=int, default=123)
585
+
586
+ p = add_command( 'create_cifar10', 'Create dataset for CIFAR-10.',
587
+ 'create_cifar10 datasets/cifar10 ~/downloads/cifar10')
588
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
589
+ p.add_argument( 'cifar10_dir', help='Directory containing CIFAR-10')
590
+
591
+ p = add_command( 'create_cifar100', 'Create dataset for CIFAR-100.',
592
+ 'create_cifar100 datasets/cifar100 ~/downloads/cifar100')
593
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
594
+ p.add_argument( 'cifar100_dir', help='Directory containing CIFAR-100')
595
+
596
+ p = add_command( 'create_svhn', 'Create dataset for SVHN.',
597
+ 'create_svhn datasets/svhn ~/downloads/svhn')
598
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
599
+ p.add_argument( 'svhn_dir', help='Directory containing SVHN')
600
+
601
+ p = add_command( 'create_lsun', 'Create dataset for single LSUN category.',
602
+ 'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')
603
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
604
+ p.add_argument( 'lmdb_dir', help='Directory containing LMDB database')
605
+ p.add_argument( '--resolution', help='Output resolution (default: 256)', type=int, default=256)
606
+ p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None)
607
+
608
+ p = add_command( 'create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.',
609
+ 'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384')
610
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
611
+ p.add_argument( 'lmdb_dir', help='Directory containing LMDB database')
612
+ p.add_argument( '--width', help='Output width (default: 512)', type=int, default=512)
613
+ p.add_argument( '--height', help='Output height (default: 384)', type=int, default=384)
614
+ p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None)
615
+
616
+ p = add_command( 'create_celeba', 'Create dataset for CelebA.',
617
+ 'create_celeba datasets/celeba ~/downloads/celeba')
618
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
619
+ p.add_argument( 'celeba_dir', help='Directory containing CelebA')
620
+ p.add_argument( '--cx', help='Center X coordinate (default: 89)', type=int, default=89)
621
+ p.add_argument( '--cy', help='Center Y coordinate (default: 121)', type=int, default=121)
622
+
623
+ p = add_command( 'create_from_images', 'Create dataset from a directory full of images.',
624
+ 'create_from_images datasets/mydataset myimagedir')
625
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
626
+ p.add_argument( 'image_dir', help='Directory containing the images')
627
+ p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1)
628
+
629
+ p = add_command( 'create_from_hdf5', 'Create dataset from legacy HDF5 archive.',
630
+ 'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')
631
+ p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
632
+ p.add_argument( 'hdf5_filename', help='HDF5 archive containing the images')
633
+ p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1)
634
+
635
+ args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])
636
+ func = globals()[args.command]
637
+ del args.command
638
+ func(**vars(args))
639
+
640
+ #----------------------------------------------------------------------------
641
+
642
+ if __name__ == "__main__":
643
+ execute_cmdline(sys.argv)
644
+
645
+ #----------------------------------------------------------------------------
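The subcommands registered above can also be driven programmatically through execute_cmdline(); a minimal sketch, assuming the module is importable as dataset_tool and using placeholder directory names:

    # Equivalent to: python dataset_tool.py create_from_images datasets/mydataset myimagedir --shuffle 1
    import dataset_tool
    dataset_tool.execute_cmdline([
        "dataset_tool.py", "create_from_images",
        "datasets/mydataset", "myimagedir",
        "--shuffle", "1",
    ])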
models/stylegan/stylegan_tf/dnnlib/__init__.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ from . import submission
+
+ from .submission.run_context import RunContext
+
+ from .submission.submit import SubmitTarget
+ from .submission.submit import PathType
+ from .submission.submit import SubmitConfig
+ from .submission.submit import get_path_from_template
+ from .submission.submit import submit_run
+
+ from .util import EasyDict
+
+ submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
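As the comment on the last line notes, dnnlib.submit_config is populated only while a submitted run is executing; a minimal sketch of a run function that relies on it (the function name and kwargs are hypothetical):

    import dnnlib

    def example_training_run(submit_config, **kwargs):
        # run_wrapper() assigns the package-level variable before calling the run function
        assert dnnlib.submit_config is submit_config
        print("Writing results to", submit_config.run_dir)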
models/stylegan/stylegan_tf/dnnlib/submission/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ from . import run_context
+ from . import submit
models/stylegan/stylegan_tf/dnnlib/submission/_internal/run.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """Helper for launching run functions in computing clusters.
+
+ During the submit process, this file is copied to the appropriate run dir.
+ When the job is launched in the cluster, this module is the first thing that
+ is run inside the docker container.
+ """
+
+ import os
+ import pickle
+ import sys
+
+ # PYTHONPATH should have been set so that the run_dir/src is in it
+ import dnnlib
+
+ def main():
+     if not len(sys.argv) >= 4:
+         raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")
+
+     run_dir = str(sys.argv[1])
+     task_name = str(sys.argv[2])
+     host_name = str(sys.argv[3])
+
+     submit_config_path = os.path.join(run_dir, "submit_config.pkl")
+
+     # SubmitConfig should have been pickled to the run dir
+     if not os.path.exists(submit_config_path):
+         raise RuntimeError("SubmitConfig pickle file does not exist!")
+
+     submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb"))
+     dnnlib.submission.submit.set_user_name_override(submit_config.user_name)
+
+     submit_config.task_name = task_name
+     submit_config.host_name = host_name
+
+     dnnlib.submission.submit.run_wrapper(submit_config)
+
+ if __name__ == "__main__":
+     main()
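Inside the run dir, the copied launcher is therefore started with three positional arguments; a hedged sketch of an equivalent invocation, with placeholder run dir, task name, and host name:

    import subprocess
    # python run.py <run_dir> <task_name> <host_name>
    subprocess.check_call(["python", "run.py", "results/00000-example", "example-task", "localhost"])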
models/stylegan/stylegan_tf/dnnlib/submission/run_context.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Helpers for managing the run/training loop."""
9
+
10
+ import datetime
11
+ import json
12
+ import os
13
+ import pprint
14
+ import time
15
+ import types
16
+
17
+ from typing import Any
18
+
19
+ from . import submit
20
+
21
+
22
+ class RunContext(object):
23
+ """Helper class for managing the run/training loop.
24
+
25
+ The context will hide the implementation details of a basic run/training loop.
26
+ It will set things up properly, tell if run should be stopped, and then cleans up.
27
+ User should call update periodically and use should_stop to determine if run should be stopped.
28
+
29
+ Args:
30
+ submit_config: The SubmitConfig that is used for the current run.
31
+ config_module: The whole config module that is used for the current run.
32
+ max_epoch: Optional cached value for the max_epoch variable used in update.
33
+ """
34
+
35
+ def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
36
+ self.submit_config = submit_config
37
+ self.should_stop_flag = False
38
+ self.has_closed = False
39
+ self.start_time = time.time()
40
+ self.last_update_time = time.time()
41
+ self.last_update_interval = 0.0
42
+ self.max_epoch = max_epoch
43
+
44
+ # pretty print the all the relevant content of the config module to a text file
45
+ if config_module is not None:
46
+ with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
47
+ filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
48
+ pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)
49
+
50
+ # write out details about the run to a text file
51
+ self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
52
+ with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
53
+ pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
54
+
55
+ def __enter__(self) -> "RunContext":
56
+ return self
57
+
58
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
59
+ self.close()
60
+
61
+ def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
62
+ """Do general housekeeping and keep the state of the context up-to-date.
63
+ Should be called often enough but not in a tight loop."""
64
+ assert not self.has_closed
65
+
66
+ self.last_update_interval = time.time() - self.last_update_time
67
+ self.last_update_time = time.time()
68
+
69
+ if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
70
+ self.should_stop_flag = True
71
+
72
+ max_epoch_val = self.max_epoch if max_epoch is None else max_epoch
73
+
74
+ def should_stop(self) -> bool:
75
+ """Tell whether a stopping condition has been triggered one way or another."""
76
+ return self.should_stop_flag
77
+
78
+ def get_time_since_start(self) -> float:
79
+ """How much time has passed since the creation of the context."""
80
+ return time.time() - self.start_time
81
+
82
+ def get_time_since_last_update(self) -> float:
83
+ """How much time has passed since the last call to update."""
84
+ return time.time() - self.last_update_time
85
+
86
+ def get_last_update_interval(self) -> float:
87
+ """How much time passed between the previous two calls to update."""
88
+ return self.last_update_interval
89
+
90
+ def close(self) -> None:
91
+ """Close the context and clean up.
92
+ Should only be called once."""
93
+ if not self.has_closed:
94
+ # update the run.txt with stopping time
95
+ self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
96
+ with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
97
+ pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
98
+
99
+ self.has_closed = True
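A minimal sketch of the loop described in the RunContext docstring, assuming a hypothetical run function and epoch count; update() keeps the bookkeeping current and should_stop() honours an abort.txt dropped into the run dir:

    import dnnlib

    def example_training_run(submit_config):
        with dnnlib.RunContext(submit_config, max_epoch=100) as ctx:
            for epoch in range(100):
                if ctx.should_stop():
                    break
                # ... one epoch of training goes here ...
                ctx.update(loss=0.0, cur_epoch=epoch, max_epoch=100)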
models/stylegan/stylegan_tf/dnnlib/submission/submit.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Submit a function to be run either locally or in a computing cluster."""
9
+
10
+ import copy
11
+ import io
12
+ import os
13
+ import pathlib
14
+ import pickle
15
+ import platform
16
+ import pprint
17
+ import re
18
+ import shutil
19
+ import time
20
+ import traceback
21
+
22
+ import zipfile
23
+
24
+ from enum import Enum
25
+
26
+ from .. import util
27
+ from ..util import EasyDict
28
+
29
+
30
+ class SubmitTarget(Enum):
31
+ """The target where the function should be run.
32
+
33
+ LOCAL: Run it locally.
34
+ """
35
+ LOCAL = 1
36
+
37
+
38
+ class PathType(Enum):
39
+ """Determines in which format should a path be formatted.
40
+
41
+ WINDOWS: Format with Windows style.
42
+ LINUX: Format with Linux/Posix style.
43
+ AUTO: Use current OS type to select either WINDOWS or LINUX.
44
+ """
45
+ WINDOWS = 1
46
+ LINUX = 2
47
+ AUTO = 3
48
+
49
+
50
+ _user_name_override = None
51
+
52
+
53
+ class SubmitConfig(util.EasyDict):
54
+ """Strongly typed config dict needed to submit runs.
55
+
56
+ Attributes:
57
+ run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
58
+ run_desc: Description of the run. Will be used in the run dir and task name.
59
+ run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
60
+ run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
61
+ submit_target: Submit target enum value. Used to select where the run is actually launched.
62
+ num_gpus: Number of GPUs used/requested for the run.
63
+ print_info: Whether to print debug information when submitting.
64
+ ask_confirmation: Whether to ask a confirmation before submitting.
65
+ run_id: Automatically populated value during submit.
66
+ run_name: Automatically populated value during submit.
67
+ run_dir: Automatically populated value during submit.
68
+ run_func_name: Automatically populated value during submit.
69
+ run_func_kwargs: Automatically populated value during submit.
70
+ user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
71
+ task_name: Automatically populated value during submit.
72
+ host_name: Automatically populated value during submit.
73
+ """
74
+
75
+ def __init__(self):
76
+ super().__init__()
77
+
78
+ # run (set these)
79
+ self.run_dir_root = "" # should always be passed through get_path_from_template
80
+ self.run_desc = ""
81
+ self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
82
+ self.run_dir_extra_files = None
83
+
84
+ # submit (set these)
85
+ self.submit_target = SubmitTarget.LOCAL
86
+ self.num_gpus = 1
87
+ self.print_info = False
88
+ self.ask_confirmation = False
89
+
90
+ # (automatically populated)
91
+ self.run_id = None
92
+ self.run_name = None
93
+ self.run_dir = None
94
+ self.run_func_name = None
95
+ self.run_func_kwargs = None
96
+ self.user_name = None
97
+ self.task_name = None
98
+ self.host_name = "localhost"
99
+
100
+
101
+ def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
102
+ """Replace tags in the given path template and return either Windows or Linux formatted path."""
103
+ # automatically select path type depending on running OS
104
+ if path_type == PathType.AUTO:
105
+ if platform.system() == "Windows":
106
+ path_type = PathType.WINDOWS
107
+ elif platform.system() == "Linux":
108
+ path_type = PathType.LINUX
109
+ else:
110
+ raise RuntimeError("Unknown platform")
111
+
112
+ path_template = path_template.replace("<USERNAME>", get_user_name())
113
+
114
+ # return correctly formatted path
115
+ if path_type == PathType.WINDOWS:
116
+ return str(pathlib.PureWindowsPath(path_template))
117
+ elif path_type == PathType.LINUX:
118
+ return str(pathlib.PurePosixPath(path_template))
119
+ else:
120
+ raise RuntimeError("Unknown platform")
121
+
122
+
123
+ def get_template_from_path(path: str) -> str:
124
+ """Convert a normal path back to its template representation."""
125
+ # replace all path parts with the template tags
126
+ path = path.replace("\\", "/")
127
+ return path
128
+
129
+
130
+ def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
131
+ """Convert a normal path to template and the convert it back to a normal path with given path type."""
132
+ path_template = get_template_from_path(path)
133
+ path = get_path_from_template(path_template, path_type)
134
+ return path
135
+
136
+
137
+ def set_user_name_override(name: str) -> None:
138
+ """Set the global username override value."""
139
+ global _user_name_override
140
+ _user_name_override = name
141
+
142
+
143
+ def get_user_name():
144
+ """Get the current user name."""
145
+ if _user_name_override is not None:
146
+ return _user_name_override
147
+ elif platform.system() == "Windows":
148
+ return os.getlogin()
149
+ elif platform.system() == "Linux":
150
+ try:
151
+ import pwd # pylint: disable=import-error
152
+ return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member
153
+ except:
154
+ return "unknown"
155
+ else:
156
+ raise RuntimeError("Unknown platform")
157
+
158
+
159
+ def _create_run_dir_local(submit_config: SubmitConfig) -> str:
160
+ """Create a new run dir with increasing ID number at the start."""
161
+ run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)
162
+
163
+ if not os.path.exists(run_dir_root):
164
+ print("Creating the run dir root: {}".format(run_dir_root))
165
+ os.makedirs(run_dir_root)
166
+
167
+ submit_config.run_id = _get_next_run_id_local(run_dir_root)
168
+ submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
169
+ run_dir = os.path.join(run_dir_root, submit_config.run_name)
170
+
171
+ if os.path.exists(run_dir):
172
+ raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
173
+
174
+ print("Creating the run dir: {}".format(run_dir))
175
+ os.makedirs(run_dir)
176
+
177
+ return run_dir
178
+
179
+
180
+ def _get_next_run_id_local(run_dir_root: str) -> int:
181
+ """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
182
+ dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
183
+ r = re.compile("^\\d+") # match one or more digits at the start of the string
184
+ run_id = 0
185
+
186
+ for dir_name in dir_names:
187
+ m = r.match(dir_name)
188
+
189
+ if m is not None:
190
+ i = int(m.group())
191
+ run_id = max(run_id, i + 1)
192
+
193
+ return run_id
194
+
195
+
196
+ def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
197
+ """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
198
+ print("Copying files to the run dir")
199
+ files = []
200
+
201
+ run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
202
+ assert '.' in submit_config.run_func_name
203
+ for _idx in range(submit_config.run_func_name.count('.') - 1):
204
+ run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
205
+ files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)
206
+
207
+ dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
208
+ files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)
209
+
210
+ if submit_config.run_dir_extra_files is not None:
211
+ files += submit_config.run_dir_extra_files
212
+
213
+ files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
214
+ files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]
215
+
216
+ util.copy_files_and_create_dirs(files)
217
+
218
+ pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb"))
219
+
220
+ with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
221
+ pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
222
+
223
+
224
+ def run_wrapper(submit_config: SubmitConfig) -> None:
225
+ """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
226
+ is_local = submit_config.submit_target == SubmitTarget.LOCAL
227
+
228
+ checker = None
229
+
230
+ # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
231
+ if is_local:
232
+ logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
233
+ else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
234
+ logger = util.Logger(file_name=None, should_flush=True)
235
+
236
+ import dnnlib
237
+ dnnlib.submit_config = submit_config
238
+
239
+ try:
240
+ print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
241
+ start_time = time.time()
242
+ util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
243
+ print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
244
+ except:
245
+ if is_local:
246
+ raise
247
+ else:
248
+ traceback.print_exc()
249
+
250
+ log_src = os.path.join(submit_config.run_dir, "log.txt")
251
+ log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
252
+ shutil.copyfile(log_src, log_dst)
253
+ finally:
254
+ open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()
255
+
256
+ dnnlib.submit_config = None
257
+ logger.close()
258
+
259
+ if checker is not None:
260
+ checker.stop()
261
+
262
+
263
+ def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
264
+ """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
265
+ submit_config = copy.copy(submit_config)
266
+
267
+ if submit_config.user_name is None:
268
+ submit_config.user_name = get_user_name()
269
+
270
+ submit_config.run_func_name = run_func_name
271
+ submit_config.run_func_kwargs = run_func_kwargs
272
+
273
+ assert submit_config.submit_target == SubmitTarget.LOCAL
274
+ if submit_config.submit_target in {SubmitTarget.LOCAL}:
275
+ run_dir = _create_run_dir_local(submit_config)
276
+
277
+ submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
278
+ submit_config.run_dir = run_dir
279
+ _populate_run_dir(run_dir, submit_config)
280
+
281
+ if submit_config.print_info:
282
+ print("\nSubmit config:\n")
283
+ pprint.pprint(submit_config, indent=4, width=200, compact=False)
284
+ print()
285
+
286
+ if submit_config.ask_confirmation:
287
+ if not util.ask_yes_no("Continue submitting the job?"):
288
+ return
289
+
290
+ run_wrapper(submit_config)
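Putting the pieces together, a local submit might look like the following sketch; the run dir root, description, run function name, and keyword argument are placeholders:

    import dnnlib

    submit_config = dnnlib.SubmitConfig()
    submit_config.run_dir_root = "results"
    submit_config.run_desc = "example"
    submit_config.submit_target = dnnlib.SubmitTarget.LOCAL
    submit_config.num_gpus = 1
    dnnlib.submit_run(submit_config, "train.example_training_run", learning_rate=0.001)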
models/stylegan/stylegan_tf/dnnlib/tflib/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ from . import autosummary
+ from . import network
+ from . import optimizer
+ from . import tfutil
+
+ from .tfutil import *
+ from .network import Network
+
+ from .optimizer import Optimizer
models/stylegan/stylegan_tf/dnnlib/tflib/autosummary.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Helper for adding automatically tracked values to Tensorboard.
9
+
10
+ Autosummary creates an identity op that internally keeps track of the input
11
+ values and automatically shows up in TensorBoard. The reported value
12
+ represents an average over input components. The average is accumulated
13
+ constantly over time and flushed when save_summaries() is called.
14
+
15
+ Notes:
16
+ - The output tensor must be used as an input for something else in the
17
+ graph. Otherwise, the autosummary op will not get executed, and the average
18
+ value will not get accumulated.
19
+ - It is perfectly fine to include autosummaries with the same name in
20
+ several places throughout the graph, even if they are executed concurrently.
21
+ - It is ok to also pass in a python scalar or numpy array. In this case, it
22
+ is added to the average immediately.
23
+ """
24
+
25
+ from collections import OrderedDict
26
+ import numpy as np
27
+ import tensorflow as tf
28
+ from tensorboard import summary as summary_lib
29
+ from tensorboard.plugins.custom_scalar import layout_pb2
30
+
31
+ from . import tfutil
32
+ from .tfutil import TfExpression
33
+ from .tfutil import TfExpressionEx
34
+
35
+ _dtype = tf.float64
36
+ _vars = OrderedDict() # name => [var, ...]
37
+ _immediate = OrderedDict() # name => update_op, update_value
38
+ _finalized = False
39
+ _merge_op = None
40
+
41
+
42
+ def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
43
+ """Internal helper for creating autosummary accumulators."""
44
+ assert not _finalized
45
+ name_id = name.replace("/", "_")
46
+ v = tf.cast(value_expr, _dtype)
47
+
48
+ if v.shape.is_fully_defined():
49
+ size = np.prod(tfutil.shape_to_list(v.shape))
50
+ size_expr = tf.constant(size, dtype=_dtype)
51
+ else:
52
+ size = None
53
+ size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
54
+
55
+ if size == 1:
56
+ if v.shape.ndims != 0:
57
+ v = tf.reshape(v, [])
58
+ v = [size_expr, v, tf.square(v)]
59
+ else:
60
+ v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
61
+ v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
62
+
63
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
64
+ var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
65
+ update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
66
+
67
+ if name in _vars:
68
+ _vars[name].append(var)
69
+ else:
70
+ _vars[name] = [var]
71
+ return update_op
72
+
73
+
74
+ def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
75
+ """Create a new autosummary.
76
+
77
+ Args:
78
+ name: Name to use in TensorBoard
79
+ value: TensorFlow expression or python value to track
80
+ passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
81
+
82
+ Example use of the passthru mechanism:
83
+
84
+ n = autosummary('l2loss', loss, passthru=n)
85
+
86
+ This is a shorthand for the following code:
87
+
88
+ with tf.control_dependencies([autosummary('l2loss', loss)]):
89
+ n = tf.identity(n)
90
+ """
91
+ tfutil.assert_tf_initialized()
92
+ name_id = name.replace("/", "_")
93
+
94
+ if tfutil.is_tf_expression(value):
95
+ with tf.name_scope("summary_" + name_id), tf.device(value.device):
96
+ update_op = _create_var(name, value)
97
+ with tf.control_dependencies([update_op]):
98
+ return tf.identity(value if passthru is None else passthru)
99
+
100
+ else: # python scalar or numpy array
101
+ if name not in _immediate:
102
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
103
+ update_value = tf.placeholder(_dtype)
104
+ update_op = _create_var(name, update_value)
105
+ _immediate[name] = update_op, update_value
106
+
107
+ update_op, update_value = _immediate[name]
108
+ tfutil.run(update_op, {update_value: value})
109
+ return value if passthru is None else passthru
110
+
111
+
112
+ def finalize_autosummaries() -> None:
113
+ """Create the necessary ops to include autosummaries in TensorBoard report.
114
+ Note: This should be done only once per graph.
115
+ """
116
+ global _finalized
117
+ tfutil.assert_tf_initialized()
118
+
119
+ if _finalized:
120
+ return None
121
+
122
+ _finalized = True
123
+ tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
124
+
125
+ # Create summary ops.
126
+ with tf.device(None), tf.control_dependencies(None):
127
+ for name, vars_list in _vars.items():
128
+ name_id = name.replace("/", "_")
129
+ with tfutil.absolute_name_scope("Autosummary/" + name_id):
130
+ moments = tf.add_n(vars_list)
131
+ moments /= moments[0]
132
+ with tf.control_dependencies([moments]): # read before resetting
133
+ reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
134
+ with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
135
+ mean = moments[1]
136
+ std = tf.sqrt(moments[2] - tf.square(moments[1]))
137
+ tf.summary.scalar(name, mean)
138
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
139
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
140
+
141
+ # Group by category and chart name.
142
+ cat_dict = OrderedDict()
143
+ for series_name in sorted(_vars.keys()):
144
+ p = series_name.split("/")
145
+ cat = p[0] if len(p) >= 2 else ""
146
+ chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
147
+ if cat not in cat_dict:
148
+ cat_dict[cat] = OrderedDict()
149
+ if chart not in cat_dict[cat]:
150
+ cat_dict[cat][chart] = []
151
+ cat_dict[cat][chart].append(series_name)
152
+
153
+ # Setup custom_scalar layout.
154
+ categories = []
155
+ for cat_name, chart_dict in cat_dict.items():
156
+ charts = []
157
+ for chart_name, series_names in chart_dict.items():
158
+ series = []
159
+ for series_name in series_names:
160
+ series.append(layout_pb2.MarginChartContent.Series(
161
+ value=series_name,
162
+ lower="xCustomScalars/" + series_name + "/margin_lo",
163
+ upper="xCustomScalars/" + series_name + "/margin_hi"))
164
+ margin = layout_pb2.MarginChartContent(series=series)
165
+ charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
166
+ categories.append(layout_pb2.Category(title=cat_name, chart=charts))
167
+ layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
168
+ return layout
169
+
170
+ def save_summaries(file_writer, global_step=None):
171
+ """Call FileWriter.add_summary() with all summaries in the default graph,
172
+ automatically finalizing and merging them on the first call.
173
+ """
174
+ global _merge_op
175
+ tfutil.assert_tf_initialized()
176
+
177
+ if _merge_op is None:
178
+ layout = finalize_autosummaries()
179
+ if layout is not None:
180
+ file_writer.add_summary(layout)
181
+ with tf.device(None), tf.control_dependencies(None):
182
+ _merge_op = tf.summary.merge_all()
183
+
184
+ file_writer.add_summary(_merge_op.eval(), global_step)
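A minimal sketch of the workflow described in the module docstring (TensorFlow 1.x graph mode; the scalar value, log directory, and the availability of tflib.init_tf() and tflib.run() from tfutil are assumptions):

    import tensorflow as tf
    from dnnlib import tflib
    from dnnlib.tflib import autosummary

    tflib.init_tf()                                         # assumed helper exposed by tfutil
    loss = tf.constant(0.5)                                 # placeholder for a real loss tensor
    loss = autosummary.autosummary("Loss/total", loss)      # identity op with a tracking side effect
    writer = tf.summary.FileWriter("logs")
    tflib.run(loss)                                         # executing the op accumulates the running average
    autosummary.save_summaries(writer, global_step=0)       # flushes the averages to TensorBoard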
models/stylegan/stylegan_tf/dnnlib/tflib/network.py ADDED
@@ -0,0 +1,591 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Helper for managing networks."""
9
+
10
+ import types
11
+ import inspect
12
+ import re
13
+ import uuid
14
+ import sys
15
+ import numpy as np
16
+ import tensorflow as tf
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, List, Tuple, Union
20
+
21
+ from . import tfutil
22
+ from .. import util
23
+
24
+ from .tfutil import TfExpression, TfExpressionEx
25
+
26
+ _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
27
+ _import_module_src = dict() # Source code for temporary modules created during pickle import.
28
+
29
+
30
+ def import_handler(handler_func):
31
+ """Function decorator for declaring custom import handlers."""
32
+ _import_handlers.append(handler_func)
33
+ return handler_func
34
+
35
+
36
+ class Network:
37
+ """Generic network abstraction.
38
+
39
+ Acts as a convenience wrapper for a parameterized network construction
40
+ function, providing several utility methods and convenient access to
41
+ the inputs/outputs/weights.
42
+
43
+ Network objects can be safely pickled and unpickled for long-term
44
+ archival purposes. The pickling works reliably as long as the underlying
45
+ network construction function is defined in a standalone Python module
46
+ that has no side effects or application-specific imports.
47
+
48
+ Args:
49
+ name: Network name. Used to select TensorFlow name and variable scopes.
50
+ func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
51
+ static_kwargs: Keyword arguments to be passed in to the network construction function.
52
+
53
+ Attributes:
54
+ name: User-specified name, defaults to build func name if None.
55
+ scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
56
+ static_kwargs: Arguments passed to the user-supplied build func.
57
+ components: Container for sub-networks. Passed to the build func, and retained between calls.
58
+ num_inputs: Number of input tensors.
59
+ num_outputs: Number of output tensors.
60
+ input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
61
+ output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
62
+ input_shape: Short-hand for input_shapes[0].
63
+ output_shape: Short-hand for output_shapes[0].
64
+ input_templates: Input placeholders in the template graph.
65
+ output_templates: Output tensors in the template graph.
66
+ input_names: Name string for each input.
67
+ output_names: Name string for each output.
68
+ own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
69
+ vars: All variables (local_name => var).
70
+ trainables: All trainable variables (local_name => var).
71
+ var_global_to_local: Mapping from variable global names to local names.
72
+ """
73
+
74
+ def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
75
+ tfutil.assert_tf_initialized()
76
+ assert isinstance(name, str) or name is None
77
+ assert func_name is not None
78
+ assert isinstance(func_name, str) or util.is_top_level_function(func_name)
79
+ assert util.is_pickleable(static_kwargs)
80
+
81
+ self._init_fields()
82
+ self.name = name
83
+ self.static_kwargs = util.EasyDict(static_kwargs)
84
+
85
+ # Locate the user-specified network build function.
86
+ if util.is_top_level_function(func_name):
87
+ func_name = util.get_top_level_function_name(func_name)
88
+ module, self._build_func_name = util.get_module_from_obj_name(func_name)
89
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
90
+ assert callable(self._build_func)
91
+
92
+ # Dig up source code for the module containing the build function.
93
+ self._build_module_src = _import_module_src.get(module, None)
94
+ if self._build_module_src is None:
95
+ self._build_module_src = inspect.getsource(module)
96
+
97
+ # Init TensorFlow graph.
98
+ self._init_graph()
99
+ self.reset_own_vars()
100
+
101
+ def _init_fields(self) -> None:
102
+ self.name = None
103
+ self.scope = None
104
+ self.static_kwargs = util.EasyDict()
105
+ self.components = util.EasyDict()
106
+ self.num_inputs = 0
107
+ self.num_outputs = 0
108
+ self.input_shapes = [[]]
109
+ self.output_shapes = [[]]
110
+ self.input_shape = []
111
+ self.output_shape = []
112
+ self.input_templates = []
113
+ self.output_templates = []
114
+ self.input_names = []
115
+ self.output_names = []
116
+ self.own_vars = OrderedDict()
117
+ self.vars = OrderedDict()
118
+ self.trainables = OrderedDict()
119
+ self.var_global_to_local = OrderedDict()
120
+
121
+ self._build_func = None # User-supplied build function that constructs the network.
122
+ self._build_func_name = None # Name of the build function.
123
+ self._build_module_src = None # Full source code of the module containing the build function.
124
+ self._run_cache = dict() # Cached graph data for Network.run().
125
+
126
+ def _init_graph(self) -> None:
127
+ # Collect inputs.
128
+ self.input_names = []
129
+
130
+ for param in inspect.signature(self._build_func).parameters.values():
131
+ if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
132
+ self.input_names.append(param.name)
133
+
134
+ self.num_inputs = len(self.input_names)
135
+ assert self.num_inputs >= 1
136
+
137
+ # Choose name and scope.
138
+ if self.name is None:
139
+ self.name = self._build_func_name
140
+ assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
141
+ with tf.name_scope(None):
142
+ self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)
143
+
144
+ # Finalize build func kwargs.
145
+ build_kwargs = dict(self.static_kwargs)
146
+ build_kwargs["is_template_graph"] = True
147
+ build_kwargs["components"] = self.components
148
+
149
+ # Build template graph.
150
+ with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
151
+ assert tf.get_variable_scope().name == self.scope
152
+ assert tf.get_default_graph().get_name_scope() == self.scope
153
+ with tf.control_dependencies(None): # ignore surrounding control dependencies
154
+ self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
155
+ out_expr = self._build_func(*self.input_templates, **build_kwargs)
156
+
157
+ # Collect outputs.
158
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
159
+ self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
160
+ self.num_outputs = len(self.output_templates)
161
+ assert self.num_outputs >= 1
162
+ assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
163
+
164
+ # Perform sanity checks.
165
+ if any(t.shape.ndims is None for t in self.input_templates):
166
+ raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
167
+ if any(t.shape.ndims is None for t in self.output_templates):
168
+ raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
169
+ if any(not isinstance(comp, Network) for comp in self.components.values()):
170
+ raise ValueError("Components of a Network must be Networks themselves.")
171
+ if len(self.components) != len(set(comp.name for comp in self.components.values())):
172
+ raise ValueError("Components of a Network must have unique names.")
173
+
174
+ # List inputs and outputs.
175
+ self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
176
+ self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
177
+ self.input_shape = self.input_shapes[0]
178
+ self.output_shape = self.output_shapes[0]
179
+ self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
180
+
181
+ # List variables.
182
+ self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
183
+ self.vars = OrderedDict(self.own_vars)
184
+ self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
185
+ self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
186
+ self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
187
+
188
+ def reset_own_vars(self) -> None:
189
+ """Re-initialize all variables of this network, excluding sub-networks."""
190
+ tfutil.run([var.initializer for var in self.own_vars.values()])
191
+
192
+ def reset_vars(self) -> None:
193
+ """Re-initialize all variables of this network, including sub-networks."""
194
+ tfutil.run([var.initializer for var in self.vars.values()])
195
+
196
+ def reset_trainables(self) -> None:
197
+ """Re-initialize all trainable variables of this network, including sub-networks."""
198
+ tfutil.run([var.initializer for var in self.trainables.values()])
199
+
200
+ def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
201
+ """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
202
+ assert len(in_expr) == self.num_inputs
203
+ assert not all(expr is None for expr in in_expr)
204
+
205
+ # Finalize build func kwargs.
206
+ build_kwargs = dict(self.static_kwargs)
207
+ build_kwargs.update(dynamic_kwargs)
208
+ build_kwargs["is_template_graph"] = False
209
+ build_kwargs["components"] = self.components
210
+
211
+ # Build TensorFlow graph to evaluate the network.
212
+ with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
213
+ assert tf.get_variable_scope().name == self.scope
214
+ valid_inputs = [expr for expr in in_expr if expr is not None]
215
+ final_inputs = []
216
+ for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
217
+ if expr is not None:
218
+ expr = tf.identity(expr, name=name)
219
+ else:
220
+ expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
221
+ final_inputs.append(expr)
222
+ out_expr = self._build_func(*final_inputs, **build_kwargs)
223
+
224
+ # Propagate input shapes back to the user-specified expressions.
225
+ for expr, final in zip(in_expr, final_inputs):
226
+ if isinstance(expr, tf.Tensor):
227
+ expr.set_shape(final.shape)
228
+
229
+ # Express outputs in the desired format.
230
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
231
+ if return_as_list:
232
+ out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
233
+ return out_expr
234
+
235
+ def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
236
+ """Get the local name of a given variable, without any surrounding name scopes."""
237
+ assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
238
+ global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
239
+ return self.var_global_to_local[global_name]
240
+
241
+ def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
242
+ """Find variable by local or global name."""
243
+ assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
244
+ return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
245
+
246
+ def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
247
+ """Get the value of a given variable as NumPy array.
248
+ Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
249
+ return self.find_var(var_or_local_name).eval()
250
+
251
+ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
252
+ """Set the value of a given variable based on the given NumPy array.
253
+ Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
254
+ tfutil.set_vars({self.find_var(var_or_local_name): new_value})
255
+
256
+ def __getstate__(self) -> dict:
257
+ """Pickle export."""
258
+ state = dict()
259
+ state["version"] = 3
260
+ state["name"] = self.name
261
+ state["static_kwargs"] = dict(self.static_kwargs)
262
+ state["components"] = dict(self.components)
263
+ state["build_module_src"] = self._build_module_src
264
+ state["build_func_name"] = self._build_func_name
265
+ state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
266
+ return state
267
+
268
+ def __setstate__(self, state: dict) -> None:
269
+ """Pickle import."""
270
+ # pylint: disable=attribute-defined-outside-init
271
+ tfutil.assert_tf_initialized()
272
+ self._init_fields()
273
+
274
+ # Execute custom import handlers.
275
+ for handler in _import_handlers:
276
+ state = handler(state)
277
+
278
+ # Set basic fields.
279
+ assert state["version"] in [2, 3]
280
+ self.name = state["name"]
281
+ self.static_kwargs = util.EasyDict(state["static_kwargs"])
282
+ self.components = util.EasyDict(state.get("components", {}))
283
+ self._build_module_src = state["build_module_src"]
284
+ self._build_func_name = state["build_func_name"]
285
+
286
+ # Create temporary module from the imported source code.
287
+ module_name = "_tflib_network_import_" + uuid.uuid4().hex
288
+ module = types.ModuleType(module_name)
289
+ sys.modules[module_name] = module
290
+ _import_module_src[module] = self._build_module_src
291
+ exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
292
+
293
+ # Locate network build function in the temporary module.
294
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
295
+ assert callable(self._build_func)
296
+
297
+ # Init TensorFlow graph.
298
+ self._init_graph()
299
+ self.reset_own_vars()
300
+ tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
301
+
302
+ def clone(self, name: str = None, **new_static_kwargs) -> "Network":
303
+ """Create a clone of this network with its own copy of the variables."""
304
+ # pylint: disable=protected-access
305
+ net = object.__new__(Network)
306
+ net._init_fields()
307
+ net.name = name if name is not None else self.name
308
+ net.static_kwargs = util.EasyDict(self.static_kwargs)
309
+ net.static_kwargs.update(new_static_kwargs)
310
+ net._build_module_src = self._build_module_src
311
+ net._build_func_name = self._build_func_name
312
+ net._build_func = self._build_func
313
+ net._init_graph()
314
+ net.copy_vars_from(self)
315
+ return net
316
+
317
+ def copy_own_vars_from(self, src_net: "Network") -> None:
318
+ """Copy the values of all variables from the given network, excluding sub-networks."""
319
+ names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
320
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
321
+
322
+ def copy_vars_from(self, src_net: "Network") -> None:
323
+ """Copy the values of all variables from the given network, including sub-networks."""
324
+ names = [name for name in self.vars.keys() if name in src_net.vars]
325
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
326
+
327
+ def copy_trainables_from(self, src_net: "Network") -> None:
328
+ """Copy the values of all trainable variables from the given network, including sub-networks."""
329
+ names = [name for name in self.trainables.keys() if name in src_net.trainables]
330
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
331
+
332
+ def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
333
+ """Create new network with the given parameters, and copy all variables from this network."""
334
+ if new_name is None:
335
+ new_name = self.name
336
+ static_kwargs = dict(self.static_kwargs)
337
+ static_kwargs.update(new_static_kwargs)
338
+ net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
339
+ net.copy_vars_from(self)
340
+ return net
341
+
342
+ def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
343
+ """Construct a TensorFlow op that updates the variables of this network
344
+ to be slightly closer to those of the given network."""
345
+ with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
346
+ ops = []
347
+ for name, var in self.vars.items():
348
+ if name in src_net.vars:
349
+ cur_beta = beta if name in self.trainables else beta_nontrainable
350
+ new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
351
+ ops.append(var.assign(new_value))
352
+ return tf.group(*ops)
353
+
354
+ def run(self,
355
+ *in_arrays: Tuple[Union[np.ndarray, None], ...],
356
+ input_transform: dict = None,
357
+ output_transform: dict = None,
358
+ return_as_list: bool = False,
359
+ print_progress: bool = False,
360
+ minibatch_size: int = None,
361
+ num_gpus: int = 1,
362
+ assume_frozen: bool = False,
363
+ **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
364
+ """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
365
+
366
+ Args:
367
+ input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
368
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the input
369
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
370
+ output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
371
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the output
372
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
373
+ return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
374
+ print_progress: Print progress to the console? Useful for very large input arrays.
375
+ minibatch_size: Maximum minibatch size to use, None = disable batching.
376
+ num_gpus: Number of GPUs to use.
377
+ assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
378
+ dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
379
+ """
380
+ assert len(in_arrays) == self.num_inputs
381
+ assert not all(arr is None for arr in in_arrays)
382
+ assert input_transform is None or util.is_top_level_function(input_transform["func"])
383
+ assert output_transform is None or util.is_top_level_function(output_transform["func"])
384
+ output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
385
+ num_items = in_arrays[0].shape[0]
386
+ if minibatch_size is None:
387
+ minibatch_size = num_items
388
+
389
+ # Construct unique hash key from all arguments that affect the TensorFlow graph.
390
+ key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
391
+ def unwind_key(obj):
392
+ if isinstance(obj, dict):
393
+ return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
394
+ if callable(obj):
395
+ return util.get_top_level_function_name(obj)
396
+ return obj
397
+ key = repr(unwind_key(key))
398
+
399
+ # Build graph.
400
+ if key not in self._run_cache:
401
+ with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
402
+ with tf.device("/cpu:0"):
403
+ in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
404
+ in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
405
+
406
+ out_split = []
407
+ for gpu in range(num_gpus):
408
+ with tf.device("/gpu:%d" % gpu):
409
+ net_gpu = self.clone() if assume_frozen else self
410
+ in_gpu = in_split[gpu]
411
+
412
+ if input_transform is not None:
413
+ in_kwargs = dict(input_transform)
414
+ in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
415
+ in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
416
+
417
+ assert len(in_gpu) == self.num_inputs
418
+ out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
419
+
420
+ if output_transform is not None:
421
+ out_kwargs = dict(output_transform)
422
+ out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
423
+ out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
424
+
425
+ assert len(out_gpu) == self.num_outputs
426
+ out_split.append(out_gpu)
427
+
428
+ with tf.device("/cpu:0"):
429
+ out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
430
+ self._run_cache[key] = in_expr, out_expr
431
+
432
+ # Run minibatches.
433
+ in_expr, out_expr = self._run_cache[key]
434
+ out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]
435
+
436
+ for mb_begin in range(0, num_items, minibatch_size):
437
+ if print_progress:
438
+ print("\r%d / %d" % (mb_begin, num_items), end="")
439
+
440
+ mb_end = min(mb_begin + minibatch_size, num_items)
441
+ mb_num = mb_end - mb_begin
442
+ mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
443
+ mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
444
+
445
+ for dst, src in zip(out_arrays, mb_out):
446
+ dst[mb_begin: mb_end] = src
447
+
448
+ # Done.
449
+ if print_progress:
450
+ print("\r%d / %d" % (num_items, num_items))
451
+
452
+ if not return_as_list:
453
+ out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
454
+ return out_arrays
455
+
456
+ def list_ops(self) -> List[TfExpression]:
457
+ include_prefix = self.scope + "/"
458
+ exclude_prefix = include_prefix + "_"
459
+ ops = tf.get_default_graph().get_operations()
460
+ ops = [op for op in ops if op.name.startswith(include_prefix)]
461
+ ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
462
+ return ops
463
+
464
+ def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
465
+ """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
466
+ individual layers of the network. Mainly intended to be used for reporting."""
467
+ layers = []
468
+
469
+ def recurse(scope, parent_ops, parent_vars, level):
470
+ # Ignore specific patterns.
471
+ if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
472
+ return
473
+
474
+ # Filter ops and vars by scope.
475
+ global_prefix = scope + "/"
476
+ local_prefix = global_prefix[len(self.scope) + 1:]
477
+ cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
478
+ cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
479
+ if not cur_ops and not cur_vars:
480
+ return
481
+
482
+ # Filter out all ops related to variables.
483
+ for var in [op for op in cur_ops if op.type.startswith("Variable")]:
484
+ var_prefix = var.name + "/"
485
+ cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
486
+
487
+ # Scope does not contain ops as immediate children => recurse deeper.
488
+ contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
489
+ if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
490
+ visited = set()
491
+ for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
492
+ token = rel_name.split("/")[0]
493
+ if token not in visited:
494
+ recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
495
+ visited.add(token)
496
+ return
497
+
498
+ # Report layer.
499
+ layer_name = scope[len(self.scope) + 1:]
500
+ layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
501
+ layer_trainables = [var for _name, var in cur_vars if var.trainable]
502
+ layers.append((layer_name, layer_output, layer_trainables))
503
+
504
+ recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
505
+ return layers
506
+
507
+ def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
508
+ """Print a summary table of the network structure."""
509
+ rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
510
+ rows += [["---"] * 4]
511
+ total_params = 0
512
+
513
+ for layer_name, layer_output, layer_trainables in self.list_layers():
514
+ num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
515
+ weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
516
+ weights.sort(key=lambda x: len(x.name))
517
+ if len(weights) == 0 and len(layer_trainables) == 1:
518
+ weights = layer_trainables
519
+ total_params += num_params
520
+
521
+ if not hide_layers_with_no_params or num_params != 0:
522
+ num_params_str = str(num_params) if num_params > 0 else "-"
523
+ output_shape_str = str(layer_output.shape)
524
+ weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
525
+ rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
526
+
527
+ rows += [["---"] * 4]
528
+ rows += [["Total", str(total_params), "", ""]]
529
+
530
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
531
+ print()
532
+ for row in rows:
533
+ print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
534
+ print()
535
+
536
+ def setup_weight_histograms(self, title: str = None) -> None:
537
+ """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
538
+ if title is None:
539
+ title = self.name
540
+
541
+ with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
542
+ for local_name, var in self.trainables.items():
543
+ if "/" in local_name:
544
+ p = local_name.split("/")
545
+ name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
546
+ else:
547
+ name = title + "_toplevel/" + local_name
548
+
549
+ tf.summary.histogram(name, var)
550
+
551
+ #----------------------------------------------------------------------------
552
+ # Backwards-compatible emulation of legacy output transformation in Network.run().
553
+
554
+ _print_legacy_warning = True
555
+
556
+ def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
557
+ global _print_legacy_warning
558
+ legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
559
+ if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
560
+ return output_transform, dynamic_kwargs
561
+
562
+ if _print_legacy_warning:
563
+ _print_legacy_warning = False
564
+ print()
565
+ print("WARNING: Old-style output transformations in Network.run() are deprecated.")
566
+ print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
567
+ print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
568
+ print()
569
+ assert output_transform is None
570
+
571
+ new_kwargs = dict(dynamic_kwargs)
572
+ new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
573
+ new_transform["func"] = _legacy_output_transform_func
574
+ return new_transform, new_kwargs
575
+
576
+ def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
577
+ if out_mul != 1.0:
578
+ expr = [x * out_mul for x in expr]
579
+
580
+ if out_add != 0.0:
581
+ expr = [x + out_add for x in expr]
582
+
583
+ if out_shrink > 1:
584
+ ksize = [1, 1, out_shrink, out_shrink]
585
+ expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
586
+
587
+ if out_dtype is not None:
588
+ if tf.as_dtype(out_dtype).is_integer:
589
+ expr = [tf.round(x) for x in expr]
590
+ expr = [tf.saturate_cast(x, out_dtype) for x in expr]
591
+ return expr
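`Network.run()` above replaces the deprecated `out_mul`/`out_add`/`out_shrink`/`out_dtype` keywords with an `output_transform` dict whose `func` must be a top-level function. A minimal usage sketch, assuming a TensorFlow 1.x session created via `tflib.init_tf()` and a generator `Gs` already loaded from a StyleGAN pickle (both assumptions, not shown in this file):

```python
# Hypothetical usage sketch of Network.run() with the new-style output_transform.
import numpy as np
import dnnlib.tflib as tflib

tflib.init_tf()  # create the default TF1 session
# Gs: a dnnlib.tflib.Network generator, assumed to be loaded elsewhere (e.g. from a pickle).
latents = np.random.RandomState(0).randn(4, Gs.input_shape[1])

images = Gs.run(latents, None,
                output_transform=dict(func=tflib.convert_images_to_uint8,
                                      nchw_to_nhwc=True),
                minibatch_size=2)  # -> uint8 NHWC array of shape [4, H, W, 3]

# Deprecated, but still rewritten by _handle_legacy_output_transforms():
#   Gs.run(latents, None, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
```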
models/stylegan/stylegan_tf/dnnlib/tflib/optimizer.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Helper wrapper for a Tensorflow optimizer."""
9
+
10
+ import numpy as np
11
+ import tensorflow as tf
12
+
13
+ from collections import OrderedDict
14
+ from typing import List, Union
15
+
16
+ from . import autosummary
17
+ from . import tfutil
18
+ from .. import util
19
+
20
+ from .tfutil import TfExpression, TfExpressionEx
21
+
22
+ try:
23
+ # TensorFlow 1.13
24
+ from tensorflow.python.ops import nccl_ops
25
+ except:
26
+ # Older TensorFlow versions
27
+ import tensorflow.contrib.nccl as nccl_ops
28
+
29
+ class Optimizer:
30
+ """A wrapper for tf.train.Optimizer.
31
+
32
+ Automatically takes care of:
33
+ - Gradient averaging for multi-GPU training.
34
+ - Dynamic loss scaling and typecasts for FP16 training.
35
+ - Ignoring corrupted gradients that contain NaNs/Infs.
36
+ - Reporting statistics.
37
+ - Well-chosen default settings.
38
+ """
39
+
40
+ def __init__(self,
41
+ name: str = "Train",
42
+ tf_optimizer: str = "tf.train.AdamOptimizer",
43
+ learning_rate: TfExpressionEx = 0.001,
44
+ use_loss_scaling: bool = False,
45
+ loss_scaling_init: float = 64.0,
46
+ loss_scaling_inc: float = 0.0005,
47
+ loss_scaling_dec: float = 1.0,
48
+ **kwargs):
49
+
50
+ # Init fields.
51
+ self.name = name
52
+ self.learning_rate = tf.convert_to_tensor(learning_rate)
53
+ self.id = self.name.replace("/", ".")
54
+ self.scope = tf.get_default_graph().unique_name(self.id)
55
+ self.optimizer_class = util.get_obj_by_name(tf_optimizer)
56
+ self.optimizer_kwargs = dict(kwargs)
57
+ self.use_loss_scaling = use_loss_scaling
58
+ self.loss_scaling_init = loss_scaling_init
59
+ self.loss_scaling_inc = loss_scaling_inc
60
+ self.loss_scaling_dec = loss_scaling_dec
61
+ self._grad_shapes = None # [shape, ...]
62
+ self._dev_opt = OrderedDict() # device => optimizer
63
+ self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...]
64
+ self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor)
65
+ self._updates_applied = False
66
+
67
+ def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
68
+ """Register the gradients of the given loss function with respect to the given variables.
69
+ Intended to be called once per GPU."""
70
+ assert not self._updates_applied
71
+
72
+ # Validate arguments.
73
+ if isinstance(trainable_vars, dict):
74
+ trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
75
+
76
+ assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
77
+ assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
78
+
79
+ if self._grad_shapes is None:
80
+ self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
81
+
82
+ assert len(trainable_vars) == len(self._grad_shapes)
83
+ assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))
84
+
85
+ dev = loss.device
86
+
87
+ assert all(var.device == dev for var in trainable_vars)
88
+
89
+ # Register device and compute gradients.
90
+ with tf.name_scope(self.id + "_grad"), tf.device(dev):
91
+ if dev not in self._dev_opt:
92
+ opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
93
+ assert callable(self.optimizer_class)
94
+ self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
95
+ self._dev_grads[dev] = []
96
+
97
+ loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
98
+ grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage
99
+ grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros
100
+ self._dev_grads[dev].append(grads)
101
+
102
+ def apply_updates(self) -> tf.Operation:
103
+ """Construct training op to update the registered variables based on their gradients."""
104
+ tfutil.assert_tf_initialized()
105
+ assert not self._updates_applied
106
+ self._updates_applied = True
107
+ devices = list(self._dev_grads.keys())
108
+ total_grads = sum(len(grads) for grads in self._dev_grads.values())
109
+ assert len(devices) >= 1 and total_grads >= 1
110
+ ops = []
111
+
112
+ with tfutil.absolute_name_scope(self.scope):
113
+ # Cast gradients to FP32 and calculate partial sum within each device.
114
+ dev_grads = OrderedDict() # device => [(grad, var), ...]
115
+
116
+ for dev_idx, dev in enumerate(devices):
117
+ with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
118
+ sums = []
119
+
120
+ for gv in zip(*self._dev_grads[dev]):
121
+ assert all(v is gv[0][1] for g, v in gv)
122
+ g = [tf.cast(g, tf.float32) for g, v in gv]
123
+ g = g[0] if len(g) == 1 else tf.add_n(g)
124
+ sums.append((g, gv[0][1]))
125
+
126
+ dev_grads[dev] = sums
127
+
128
+ # Sum gradients across devices.
129
+ if len(devices) > 1:
130
+ with tf.name_scope("SumAcrossGPUs"), tf.device(None):
131
+ for var_idx, grad_shape in enumerate(self._grad_shapes):
132
+ g = [dev_grads[dev][var_idx][0] for dev in devices]
133
+
134
+ if np.prod(grad_shape): # nccl does not support zero-sized tensors
135
+ g = nccl_ops.all_sum(g)
136
+
137
+ for dev, gg in zip(devices, g):
138
+ dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])
139
+
140
+ # Apply updates separately on each device.
141
+ for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
142
+ with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
143
+ # Scale gradients as needed.
144
+ if self.use_loss_scaling or total_grads > 1:
145
+ with tf.name_scope("Scale"):
146
+ coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
147
+ coef = self.undo_loss_scaling(coef)
148
+ grads = [(g * coef, v) for g, v in grads]
149
+
150
+ # Check for overflows.
151
+ with tf.name_scope("CheckOverflow"):
152
+ grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))
153
+
154
+ # Update weights and adjust loss scaling.
155
+ with tf.name_scope("UpdateWeights"):
156
+ # pylint: disable=cell-var-from-loop
157
+ opt = self._dev_opt[dev]
158
+ ls_var = self.get_loss_scaling_var(dev)
159
+
160
+ if not self.use_loss_scaling:
161
+ ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
162
+ else:
163
+ ops.append(tf.cond(grad_ok,
164
+ lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
165
+ lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))
166
+
167
+ # Report statistics on the last device.
168
+ if dev == devices[-1]:
169
+ with tf.name_scope("Statistics"):
170
+ ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
171
+ ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
172
+
173
+ if self.use_loss_scaling:
174
+ ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))
175
+
176
+ # Initialize variables and group everything into a single op.
177
+ self.reset_optimizer_state()
178
+ tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))
179
+
180
+ return tf.group(*ops, name="TrainingOp")
181
+
182
+ def reset_optimizer_state(self) -> None:
183
+ """Reset internal state of the underlying optimizer."""
184
+ tfutil.assert_tf_initialized()
185
+ tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])
186
+
187
+ def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
188
+ """Get or create variable representing log2 of the current dynamic loss scaling factor."""
189
+ if not self.use_loss_scaling:
190
+ return None
191
+
192
+ if device not in self._dev_ls_var:
193
+ with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
194
+ self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
195
+
196
+ return self._dev_ls_var[device]
197
+
198
+ def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
199
+ """Apply dynamic loss scaling for the given expression."""
200
+ assert tfutil.is_tf_expression(value)
201
+
202
+ if not self.use_loss_scaling:
203
+ return value
204
+
205
+ return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
206
+
207
+ def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
208
+ """Undo the effect of dynamic loss scaling for the given expression."""
209
+ assert tfutil.is_tf_expression(value)
210
+
211
+ if not self.use_loss_scaling:
212
+ return value
213
+
214
+ return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
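The `Optimizer` wrapper above follows a two-phase protocol: `register_gradients()` once per GPU while the per-device losses are being built, then a single `apply_updates()` that returns the grouped training op. A minimal single-GPU sketch with a toy loss (variable names are illustrative; assumes TensorFlow 1.x and the `stylegan_tf` directory on the import path):

```python
# Hypothetical single-GPU sketch of the Optimizer protocol defined above.
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib.optimizer import Optimizer

tflib.init_tf()
w = tf.Variable(tf.zeros([10]), name="toy_w")                  # toy trainable parameter
x = tf.placeholder(tf.float32, [None, 10], name="toy_x")
loss = tf.reduce_mean(tf.square(tf.reduce_sum(x * w, axis=1) - 1.0))

opt = Optimizer(name="ToyTrain", learning_rate=1e-3)
opt.register_gradients(loss, [w])   # in multi-GPU training this is called once per GPU
train_op = opt.apply_updates()      # builds gradient averaging, overflow check, Adam update

tflib.init_uninitialized_vars()     # initialize toy_w plus the optimizer's internal state
tflib.run(train_op, {x: np.random.randn(32, 10).astype(np.float32)})
```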
models/stylegan/stylegan_tf/dnnlib/tflib/tfutil.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Miscellaneous helper utils for Tensorflow."""
9
+
10
+ import os
11
+ import numpy as np
12
+ import tensorflow as tf
13
+
14
+ from typing import Any, Iterable, List, Union
15
+
16
+ TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
17
+ """A type that represents a valid Tensorflow expression."""
18
+
19
+ TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
20
+ """A type that can be converted to a valid Tensorflow expression."""
21
+
22
+
23
+ def run(*args, **kwargs) -> Any:
24
+ """Run the specified ops in the default session."""
25
+ assert_tf_initialized()
26
+ return tf.get_default_session().run(*args, **kwargs)
27
+
28
+
29
+ def is_tf_expression(x: Any) -> bool:
30
+ """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
31
+ return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
32
+
33
+
34
+ def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
35
+ """Convert a Tensorflow shape to a list of ints."""
36
+ return [dim.value for dim in shape]
37
+
38
+
39
+ def flatten(x: TfExpressionEx) -> TfExpression:
40
+ """Shortcut function for flattening a tensor."""
41
+ with tf.name_scope("Flatten"):
42
+ return tf.reshape(x, [-1])
43
+
44
+
45
+ def log2(x: TfExpressionEx) -> TfExpression:
46
+ """Logarithm in base 2."""
47
+ with tf.name_scope("Log2"):
48
+ return tf.log(x) * np.float32(1.0 / np.log(2.0))
49
+
50
+
51
+ def exp2(x: TfExpressionEx) -> TfExpression:
52
+ """Exponent in base 2."""
53
+ with tf.name_scope("Exp2"):
54
+ return tf.exp(x * np.float32(np.log(2.0)))
55
+
56
+
57
+ def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
58
+ """Linear interpolation."""
59
+ with tf.name_scope("Lerp"):
60
+ return a + (b - a) * t
61
+
62
+
63
+ def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
64
+ """Linear interpolation with clip."""
65
+ with tf.name_scope("LerpClip"):
66
+ return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
67
+
68
+
69
+ def absolute_name_scope(scope: str) -> tf.name_scope:
70
+ """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
71
+ return tf.name_scope(scope + "/")
72
+
73
+
74
+ def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
75
+ """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
76
+ return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
77
+
78
+
79
+ def _sanitize_tf_config(config_dict: dict = None) -> dict:
80
+ # Defaults.
81
+ cfg = dict()
82
+ cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
83
+ cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
84
+ cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
85
+ cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
86
+ cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
87
+
88
+ # User overrides.
89
+ if config_dict is not None:
90
+ cfg.update(config_dict)
91
+ return cfg
92
+
93
+
94
+ def init_tf(config_dict: dict = None) -> None:
95
+ """Initialize TensorFlow session using good default settings."""
96
+ # Skip if already initialized.
97
+ if tf.get_default_session() is not None:
98
+ return
99
+
100
+ # Setup config dict and random seeds.
101
+ cfg = _sanitize_tf_config(config_dict)
102
+ np_random_seed = cfg["rnd.np_random_seed"]
103
+ if np_random_seed is not None:
104
+ np.random.seed(np_random_seed)
105
+ tf_random_seed = cfg["rnd.tf_random_seed"]
106
+ if tf_random_seed == "auto":
107
+ tf_random_seed = np.random.randint(1 << 31)
108
+ if tf_random_seed is not None:
109
+ tf.set_random_seed(tf_random_seed)
110
+
111
+ # Setup environment variables.
112
+ for key, value in list(cfg.items()):
113
+ fields = key.split(".")
114
+ if fields[0] == "env":
115
+ assert len(fields) == 2
116
+ os.environ[fields[1]] = str(value)
117
+
118
+ # Create default TensorFlow session.
119
+ create_session(cfg, force_as_default=True)
120
+
121
+
122
+ def assert_tf_initialized():
123
+ """Check that TensorFlow session has been initialized."""
124
+ if tf.get_default_session() is None:
125
+ raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
126
+
127
+
128
+ def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
129
+ """Create tf.Session based on config dict."""
130
+ # Setup TensorFlow config proto.
131
+ cfg = _sanitize_tf_config(config_dict)
132
+ config_proto = tf.ConfigProto()
133
+ for key, value in cfg.items():
134
+ fields = key.split(".")
135
+ if fields[0] not in ["rnd", "env"]:
136
+ obj = config_proto
137
+ for field in fields[:-1]:
138
+ obj = getattr(obj, field)
139
+ setattr(obj, fields[-1], value)
140
+
141
+ # Create session.
142
+ session = tf.Session(config=config_proto)
143
+ if force_as_default:
144
+ # pylint: disable=protected-access
145
+ session._default_session = session.as_default()
146
+ session._default_session.enforce_nesting = False
147
+ session._default_session.__enter__() # pylint: disable=no-member
148
+
149
+ return session
150
+
151
+
152
+ def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
153
+ """Initialize all tf.Variables that have not already been initialized.
154
+
155
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
156
+ tf.variables_initializer(tf.report_uninitialized_variables()).run()
157
+ """
158
+ assert_tf_initialized()
159
+ if target_vars is None:
160
+ target_vars = tf.global_variables()
161
+
162
+ test_vars = []
163
+ test_ops = []
164
+
165
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
166
+ for var in target_vars:
167
+ assert is_tf_expression(var)
168
+
169
+ try:
170
+ tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
171
+ except KeyError:
172
+ # Op does not exist => variable may be uninitialized.
173
+ test_vars.append(var)
174
+
175
+ with absolute_name_scope(var.name.split(":")[0]):
176
+ test_ops.append(tf.is_variable_initialized(var))
177
+
178
+ init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
179
+ run([var.initializer for var in init_vars])
180
+
181
+
182
+ def set_vars(var_to_value_dict: dict) -> None:
183
+ """Set the values of given tf.Variables.
184
+
185
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
186
+ tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
187
+ """
188
+ assert_tf_initialized()
189
+ ops = []
190
+ feed_dict = {}
191
+
192
+ for var, value in var_to_value_dict.items():
193
+ assert is_tf_expression(var)
194
+
195
+ try:
196
+ setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
197
+ except KeyError:
198
+ with absolute_name_scope(var.name.split(":")[0]):
199
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
200
+ setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
201
+
202
+ ops.append(setter)
203
+ feed_dict[setter.op.inputs[1]] = value
204
+
205
+ run(ops, feed_dict)
206
+
207
+
208
+ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
209
+ """Create tf.Variable with large initial value without bloating the tf graph."""
210
+ assert_tf_initialized()
211
+ assert isinstance(initial_value, np.ndarray)
212
+ zeros = tf.zeros(initial_value.shape, initial_value.dtype)
213
+ var = tf.Variable(zeros, *args, **kwargs)
214
+ set_vars({var: initial_value})
215
+ return var
216
+
217
+
218
+ def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
219
+ """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
220
+ Can be used as an input transformation for Network.run().
221
+ """
222
+ images = tf.cast(images, tf.float32)
223
+ if nhwc_to_nchw:
224
+ images = tf.transpose(images, [0, 3, 1, 2])
225
+ return (images - drange[0]) * ((drange[1] - drange[0]) / 255)
226
+
227
+
228
+ def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
229
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
230
+ Can be used as an output transformation for Network.run().
231
+ """
232
+ images = tf.cast(images, tf.float32)
233
+ if shrink > 1:
234
+ ksize = [1, 1, shrink, shrink]
235
+ images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
236
+ if nchw_to_nhwc:
237
+ images = tf.transpose(images, [0, 2, 3, 1])
238
+ scale = 255 / (drange[1] - drange[0])
239
+ images = images * scale + (0.5 - drange[0] * scale)
240
+ return tf.saturate_cast(images, tf.uint8)
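Most entry points in this repo only touch a handful of the helpers above: `init_tf()` for session setup, `set_vars()` for efficient weight assignment, and the uint8 conversion functions as `Network.run()` transforms. A minimal sketch, assuming TensorFlow 1.x:

```python
# Hypothetical sketch of the tfutil helpers defined above (TensorFlow 1.x assumed).
import numpy as np
import tensorflow as tf
import dnnlib.tflib.tfutil as tfutil

tfutil.init_tf({"rnd.np_random_seed": 123})       # default session, seeds, config overrides

# Assign a value through a cached "setter" op instead of growing the graph each time.
var = tf.Variable(tf.zeros([4]), name="toy_var")
tfutil.init_uninitialized_vars([var])
tfutil.set_vars({var: np.arange(4, dtype=np.float32)})
print(tfutil.run(var))                            # [0. 1. 2. 3.]

# Convert a fake NCHW float batch in [-1, 1] to uint8, as done for generator outputs.
fake = tf.random_uniform([2, 3, 8, 8], minval=-1.0, maxval=1.0)
uint8_images = tfutil.run(tfutil.convert_images_to_uint8(fake))
print(uint8_images.shape, uint8_images.dtype)     # (2, 3, 8, 8) uint8
```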
models/stylegan/stylegan_tf/dnnlib/util.py ADDED
@@ -0,0 +1,405 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Miscellaneous utility classes and functions."""
9
+
10
+ import ctypes
11
+ import fnmatch
12
+ import importlib
13
+ import inspect
14
+ import numpy as np
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import types
19
+ import io
20
+ import pickle
21
+ import re
22
+ import requests
23
+ import html
24
+ import hashlib
25
+ import glob
26
+ import uuid
27
+
28
+ from distutils.util import strtobool
29
+ from typing import Any, List, Tuple, Union
30
+
31
+
32
+ # Util classes
33
+ # ------------------------------------------------------------------------------------------
34
+
35
+
36
+ class EasyDict(dict):
37
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
38
+
39
+ def __getattr__(self, name: str) -> Any:
40
+ try:
41
+ return self[name]
42
+ except KeyError:
43
+ raise AttributeError(name)
44
+
45
+ def __setattr__(self, name: str, value: Any) -> None:
46
+ self[name] = value
47
+
48
+ def __delattr__(self, name: str) -> None:
49
+ del self[name]
50
+
51
+
52
+ class Logger(object):
53
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
54
+
55
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
56
+ self.file = None
57
+
58
+ if file_name is not None:
59
+ self.file = open(file_name, file_mode)
60
+
61
+ self.should_flush = should_flush
62
+ self.stdout = sys.stdout
63
+ self.stderr = sys.stderr
64
+
65
+ sys.stdout = self
66
+ sys.stderr = self
67
+
68
+ def __enter__(self) -> "Logger":
69
+ return self
70
+
71
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
72
+ self.close()
73
+
74
+ def write(self, text: str) -> None:
75
+ """Write text to stdout (and a file) and optionally flush."""
76
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
77
+ return
78
+
79
+ if self.file is not None:
80
+ self.file.write(text)
81
+
82
+ self.stdout.write(text)
83
+
84
+ if self.should_flush:
85
+ self.flush()
86
+
87
+ def flush(self) -> None:
88
+ """Flush written text to both stdout and a file, if open."""
89
+ if self.file is not None:
90
+ self.file.flush()
91
+
92
+ self.stdout.flush()
93
+
94
+ def close(self) -> None:
95
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
96
+ self.flush()
97
+
98
+ # if using multiple loggers, prevent closing in wrong order
99
+ if sys.stdout is self:
100
+ sys.stdout = self.stdout
101
+ if sys.stderr is self:
102
+ sys.stderr = self.stderr
103
+
104
+ if self.file is not None:
105
+ self.file.close()
106
+
107
+
108
+ # Small util functions
109
+ # ------------------------------------------------------------------------------------------
110
+
111
+
112
+ def format_time(seconds: Union[int, float]) -> str:
113
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
114
+ s = int(np.rint(seconds))
115
+
116
+ if s < 60:
117
+ return "{0}s".format(s)
118
+ elif s < 60 * 60:
119
+ return "{0}m {1:02}s".format(s // 60, s % 60)
120
+ elif s < 24 * 60 * 60:
121
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
122
+ else:
123
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
124
+
125
+
126
+ def ask_yes_no(question: str) -> bool:
127
+ """Ask the user the question until the user inputs a valid answer."""
128
+ while True:
129
+ try:
130
+ print("{0} [y/n]".format(question))
131
+ return strtobool(input().lower())
132
+ except ValueError:
133
+ pass
134
+
135
+
136
+ def tuple_product(t: Tuple) -> Any:
137
+ """Calculate the product of the tuple elements."""
138
+ result = 1
139
+
140
+ for v in t:
141
+ result *= v
142
+
143
+ return result
144
+
145
+
146
+ _str_to_ctype = {
147
+ "uint8": ctypes.c_ubyte,
148
+ "uint16": ctypes.c_uint16,
149
+ "uint32": ctypes.c_uint32,
150
+ "uint64": ctypes.c_uint64,
151
+ "int8": ctypes.c_byte,
152
+ "int16": ctypes.c_int16,
153
+ "int32": ctypes.c_int32,
154
+ "int64": ctypes.c_int64,
155
+ "float32": ctypes.c_float,
156
+ "float64": ctypes.c_double
157
+ }
158
+
159
+
160
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
161
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
162
+ type_str = None
163
+
164
+ if isinstance(type_obj, str):
165
+ type_str = type_obj
166
+ elif hasattr(type_obj, "__name__"):
167
+ type_str = type_obj.__name__
168
+ elif hasattr(type_obj, "name"):
169
+ type_str = type_obj.name
170
+ else:
171
+ raise RuntimeError("Cannot infer type name from input")
172
+
173
+ assert type_str in _str_to_ctype.keys()
174
+
175
+ my_dtype = np.dtype(type_str)
176
+ my_ctype = _str_to_ctype[type_str]
177
+
178
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
179
+
180
+ return my_dtype, my_ctype
181
+
182
+
183
+ def is_pickleable(obj: Any) -> bool:
184
+ try:
185
+ with io.BytesIO() as stream:
186
+ pickle.dump(obj, stream)
187
+ return True
188
+ except:
189
+ return False
190
+
191
+
192
+ # Functionality to import modules/objects by name, and call functions by name
193
+ # ------------------------------------------------------------------------------------------
194
+
195
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
196
+ """Searches for the underlying module behind the name to some python object.
197
+ Returns the module and the object name (original name with module part removed)."""
198
+
199
+ # allow convenience shorthands, substitute them by full names
200
+ obj_name = re.sub("^np.", "numpy.", obj_name)
201
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
202
+
203
+ # list alternatives for (module_name, local_obj_name)
204
+ parts = obj_name.split(".")
205
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
206
+
207
+ # try each alternative in turn
208
+ for module_name, local_obj_name in name_pairs:
209
+ try:
210
+ module = importlib.import_module(module_name) # may raise ImportError
211
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
212
+ return module, local_obj_name
213
+ except:
214
+ pass
215
+
216
+ # maybe some of the modules themselves contain errors?
217
+ for module_name, _local_obj_name in name_pairs:
218
+ try:
219
+ importlib.import_module(module_name) # may raise ImportError
220
+ except ImportError:
221
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
222
+ raise
223
+
224
+ # maybe the requested attribute is missing?
225
+ for module_name, local_obj_name in name_pairs:
226
+ try:
227
+ module = importlib.import_module(module_name) # may raise ImportError
228
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
229
+ except ImportError:
230
+ pass
231
+
232
+ # we are out of luck, but we have no idea why
233
+ raise ImportError(obj_name)
234
+
235
+
236
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
237
+ """Traverses the object name and returns the last (rightmost) python object."""
238
+ if obj_name == '':
239
+ return module
240
+ obj = module
241
+ for part in obj_name.split("."):
242
+ obj = getattr(obj, part)
243
+ return obj
244
+
245
+
246
+ def get_obj_by_name(name: str) -> Any:
247
+ """Finds the python object with the given name."""
248
+ module, obj_name = get_module_from_obj_name(name)
249
+ return get_obj_from_module(module, obj_name)
250
+
251
+
252
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
253
+ """Finds the python object with the given name and calls it as a function."""
254
+ assert func_name is not None
255
+ func_obj = get_obj_by_name(func_name)
256
+ assert callable(func_obj)
257
+ return func_obj(*args, **kwargs)
258
+
259
+
260
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
261
+ """Get the directory path of the module containing the given object name."""
262
+ module, _ = get_module_from_obj_name(obj_name)
263
+ return os.path.dirname(inspect.getfile(module))
264
+
265
+
266
+ def is_top_level_function(obj: Any) -> bool:
267
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
268
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
269
+
270
+
271
+ def get_top_level_function_name(obj: Any) -> str:
272
+ """Return the fully-qualified name of a top-level function."""
273
+ assert is_top_level_function(obj)
274
+ return obj.__module__ + "." + obj.__name__
275
+
276
+
277
+ # File system helpers
278
+ # ------------------------------------------------------------------------------------------
279
+
280
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
281
+ """List all files recursively in a given directory while ignoring given file and directory names.
282
+ Returns list of tuples containing both absolute and relative paths."""
283
+ assert os.path.isdir(dir_path)
284
+ base_name = os.path.basename(os.path.normpath(dir_path))
285
+
286
+ if ignores is None:
287
+ ignores = []
288
+
289
+ result = []
290
+
291
+ for root, dirs, files in os.walk(dir_path, topdown=True):
292
+ for ignore_ in ignores:
293
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
294
+
295
+ # dirs need to be edited in-place
296
+ for d in dirs_to_remove:
297
+ dirs.remove(d)
298
+
299
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
300
+
301
+ absolute_paths = [os.path.join(root, f) for f in files]
302
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
303
+
304
+ if add_base_to_relative:
305
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
306
+
307
+ assert len(absolute_paths) == len(relative_paths)
308
+ result += zip(absolute_paths, relative_paths)
309
+
310
+ return result
311
+
312
+
313
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
314
+ """Takes in a list of tuples of (src, dst) paths and copies files.
315
+ Will create all necessary directories."""
316
+ for file in files:
317
+ target_dir_name = os.path.dirname(file[1])
318
+
319
+ # will create all intermediate-level directories
320
+ if not os.path.exists(target_dir_name):
321
+ os.makedirs(target_dir_name)
322
+
323
+ shutil.copyfile(file[0], file[1])
324
+
325
+
326
+ # URL helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def is_url(obj: Any) -> bool:
330
+ """Determine whether the given object is a valid URL string."""
331
+ if not isinstance(obj, str) or not "://" in obj:
332
+ return False
333
+ try:
334
+ res = requests.compat.urlparse(obj)
335
+ if not res.scheme or not res.netloc or not "." in res.netloc:
336
+ return False
337
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
338
+ if not res.scheme or not res.netloc or not "." in res.netloc:
339
+ return False
340
+ except:
341
+ return False
342
+ return True
343
+
344
+
345
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
346
+ """Download the given URL and return a binary-mode file object to access the data."""
347
+ assert is_url(url)
348
+ assert num_attempts >= 1
349
+
350
+ # Lookup from cache.
351
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
352
+ if cache_dir is not None:
353
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
354
+ if len(cache_files) == 1:
355
+ return open(cache_files[0], "rb")
356
+
357
+ # Download.
358
+ url_name = None
359
+ url_data = None
360
+ with requests.Session() as session:
361
+ if verbose:
362
+ print("Downloading %s ..." % url, end="", flush=True)
363
+ for attempts_left in reversed(range(num_attempts)):
364
+ try:
365
+ with session.get(url) as res:
366
+ res.raise_for_status()
367
+ if len(res.content) == 0:
368
+ raise IOError("No data received")
369
+
370
+ if len(res.content) < 8192:
371
+ content_str = res.content.decode("utf-8")
372
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
373
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
374
+ if len(links) == 1:
375
+ url = requests.compat.urljoin(url, links[0])
376
+ raise IOError("Google Drive virus checker nag")
377
+ if "Google Drive - Quota exceeded" in content_str:
378
+ raise IOError("Google Drive quota exceeded")
379
+
380
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
381
+ url_name = match[1] if match else url
382
+ url_data = res.content
383
+ if verbose:
384
+ print(" done")
385
+ break
386
+ except:
387
+ if not attempts_left:
388
+ if verbose:
389
+ print(" failed")
390
+ raise
391
+ if verbose:
392
+ print(".", end="", flush=True)
393
+
394
+ # Save to cache.
395
+ if cache_dir is not None:
396
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
397
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
398
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
399
+ os.makedirs(cache_dir, exist_ok=True)
400
+ with open(temp_file, "wb") as f:
401
+ f.write(url_data)
402
+ os.replace(temp_file, cache_file) # atomic
403
+
404
+ # Return data as file object.
405
+ return io.BytesIO(url_data)
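`EasyDict`, the name-resolution helpers, and `open_url()` above are the pieces of `dnnlib.util` that the rest of this port relies on. A small sketch of each (the download URL below is a placeholder, not one of the repo's checkpoints):

```python
# Hypothetical sketch of the dnnlib.util helpers defined above.
import pickle
import dnnlib.util as util

cfg = util.EasyDict(minibatch=8, lr=0.001)   # dict with attribute-style access
cfg.lr *= 0.5

print(util.format_time(98765))               # '1d 03h 26m'
print(util.get_obj_by_name("np.mean"))       # 'np.' shorthand resolves to numpy.mean

# Download-and-cache pattern used by the pretrained-model loaders in this repo.
# NOTE: the URL below is a placeholder, not a real checkpoint.
# with util.open_url("https://example.com/karras2019stylegan.pkl", cache_dir="cache") as f:
#     _G, _D, Gs = pickle.load(f)
```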
models/stylegan/stylegan_tf/generate_figures.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators."""
9
+
10
+ import os
11
+ import pickle
12
+ import numpy as np
13
+ import PIL.Image
14
+ import dnnlib
15
+ import dnnlib.tflib as tflib
16
+ import config
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Helpers for loading and using pre-trained generators.
20
+
21
+ url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
22
+ url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl
23
+ url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl
24
+ url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl
25
+ url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl
26
+
27
+ synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
28
+
29
+ _Gs_cache = dict()
30
+
31
+ def load_Gs(url):
32
+ if url not in _Gs_cache:
33
+ with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
34
+ _G, _D, Gs = pickle.load(f)
35
+ _Gs_cache[url] = Gs
36
+ return _Gs_cache[url]
37
+
38
+ #----------------------------------------------------------------------------
39
+ # Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images.
40
+
41
+ def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):
42
+ print(png)
43
+ latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1])
44
+ images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]
45
+
46
+ canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')
47
+ image_iter = iter(list(images))
48
+ for col, lod in enumerate(lods):
49
+ for row in range(rows * 2**lod):
50
+ image = PIL.Image.fromarray(next(image_iter), 'RGB')
51
+ image = image.crop((cx, cy, cx + cw, cy + ch))
52
+ image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS)
53
+ canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod))
54
+ canvas.save(png)
55
+
56
+ #----------------------------------------------------------------------------
57
+ # Figure 3: Style mixing.
58
+
59
+ def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
60
+ print(png)
61
+ src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)
62
+ dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds)
63
+ src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
64
+ dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
65
+ src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
66
+ dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)
67
+
68
+ canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
69
+ for col, src_image in enumerate(list(src_images)):
70
+ canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
71
+ for row, dst_image in enumerate(list(dst_images)):
72
+ canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
73
+ row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
74
+ row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
75
+ row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
76
+ for col, image in enumerate(list(row_images)):
77
+ canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
78
+ canvas.save(png)
79
+
80
+ #----------------------------------------------------------------------------
81
+ # Figure 4: Noise detail.
82
+
83
+ def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):
84
+ print(png)
85
+ canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white')
86
+ for row, seed in enumerate(seeds):
87
+ latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples)
88
+ images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs)
89
+ canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h))
90
+ for i in range(4):
91
+ crop = PIL.Image.fromarray(images[i + 1], 'RGB')
92
+ crop = crop.crop((650, 180, 906, 436))
93
+ crop = crop.resize((w//2, h//2), PIL.Image.NEAREST)
94
+ canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2))
95
+ diff = np.std(np.mean(images, axis=3), axis=0) * 4
96
+ diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
97
+ canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h))
98
+ canvas.save(png)
99
+
100
+ #----------------------------------------------------------------------------
101
+ # Figure 5: Noise components.
102
+
103
+ def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):
104
+ print(png)
105
+ Gsc = Gs.clone()
106
+ noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')]
107
+ noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...]
108
+ latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds)
109
+ all_images = []
110
+ for noise_range in noise_ranges:
111
+ tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)})
112
+ range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs)
113
+ range_images[flips, :, :] = range_images[flips, :, ::-1]
114
+ all_images.append(list(range_images))
115
+
116
+ canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white')
117
+ for col, col_images in enumerate(zip(*all_images)):
118
+ canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0))
119
+ canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0))
120
+ canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h))
121
+ canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h))
122
+ canvas.save(png)
123
+
124
+ #----------------------------------------------------------------------------
125
+ # Figure 8: Truncation trick.
126
+
127
+ def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):
128
+ print(png)
129
+ latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds)
130
+ dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]
131
+ dlatent_avg = Gs.get_var('dlatent_avg') # [component]
132
+
133
+ canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')
134
+ for row, dlatent in enumerate(list(dlatents)):
135
+ row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg
136
+ row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
137
+ for col, image in enumerate(list(row_images)):
138
+ canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))
139
+ canvas.save(png)
140
+
141
+ #----------------------------------------------------------------------------
142
+ # Main program.
143
+
144
+ def main():
145
+ tflib.init_tf()
146
+ os.makedirs(config.result_dir, exist_ok=True)
147
+ draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5)
148
+ draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)])
149
+ draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012])
150
+ draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1])
151
+ draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1])
152
+ draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0)
153
+ draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2)
154
+ draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1)
155
+
156
+ #----------------------------------------------------------------------------
157
+
158
+ if __name__ == "__main__":
159
+ main()
160
+
161
+ #----------------------------------------------------------------------------
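The key line in `draw_truncation_trick_figure()` above is the dlatent interpolation `(dlatent - dlatent_avg) * psi + dlatent_avg`, i.e. the truncation trick w' = w_avg + psi * (w - w_avg). A stand-alone NumPy sketch of just that step (shapes mirror the FFHQ generator; `dlatent_avg` below is a stand-in for `Gs.get_var('dlatent_avg')`):

```python
# Hypothetical stand-alone sketch of the truncation trick used above:
# w' = w_avg + psi * (w - w_avg), applied once per truncation strength psi.
import numpy as np

dlatent_avg = np.zeros([512], dtype=np.float32)                        # stand-in for Gs.get_var('dlatent_avg')
dlatent = np.random.RandomState(0).randn(18, 512).astype(np.float32)   # one seed, 18 style layers
psis = np.array([1, 0.7, 0.5, 0, -0.5, -1], dtype=np.float32)

row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * psis.reshape([-1, 1, 1]) + dlatent_avg
assert row_dlatents.shape == (6, 18, 512)                              # one column of the figure per psi
```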
models/stylegan/stylegan_tf/metrics/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ # empty
models/stylegan/stylegan_tf/metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Frechet Inception Distance (FID)."""
9
+
10
+ import os
11
+ import numpy as np
12
+ import scipy
13
+ import tensorflow as tf
14
+ import dnnlib.tflib as tflib
15
+
16
+ from metrics import metric_base
17
+ from training import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ class FID(metric_base.MetricBase):
22
+ def __init__(self, num_images, minibatch_per_gpu, **kwargs):
23
+ super().__init__(**kwargs)
24
+ self.num_images = num_images
25
+ self.minibatch_per_gpu = minibatch_per_gpu
26
+
27
+ def _evaluate(self, Gs, num_gpus):
28
+ minibatch_size = num_gpus * self.minibatch_per_gpu
29
+ inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
30
+ activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
31
+
32
+ # Calculate statistics for reals.
33
+ cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
34
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
35
+ if os.path.isfile(cache_file):
36
+ mu_real, sigma_real = misc.load_pkl(cache_file)
37
+ else:
38
+ for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
39
+ begin = idx * minibatch_size
40
+ end = min(begin + minibatch_size, self.num_images)
41
+ activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)
42
+ if end == self.num_images:
43
+ break
44
+ mu_real = np.mean(activations, axis=0)
45
+ sigma_real = np.cov(activations, rowvar=False)
46
+ misc.save_pkl((mu_real, sigma_real), cache_file)
47
+
48
+ # Construct TensorFlow graph.
49
+ result_expr = []
50
+ for gpu_idx in range(num_gpus):
51
+ with tf.device('/gpu:%d' % gpu_idx):
52
+ Gs_clone = Gs.clone()
53
+ inception_clone = inception.clone()
54
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
55
+ images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
56
+ images = tflib.convert_images_to_uint8(images)
57
+ result_expr.append(inception_clone.get_output_for(images))
58
+
59
+ # Calculate statistics for fakes.
60
+ for begin in range(0, self.num_images, minibatch_size):
61
+ end = min(begin + minibatch_size, self.num_images)
62
+ activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
63
+ mu_fake = np.mean(activations, axis=0)
64
+ sigma_fake = np.cov(activations, rowvar=False)
65
+
66
+ # Calculate FID.
67
+ m = np.square(mu_fake - mu_real).sum()
68
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
69
+ dist = m + np.trace(sigma_fake + sigma_real - 2*s)
70
+ self._report_result(np.real(dist))
71
+
72
+ #----------------------------------------------------------------------------
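For readers skimming this diff: the `_evaluate` method above boils down to the closed-form Fréchet distance between two Gaussian fits of Inception features. Below is a minimal standalone sketch of that final step using only NumPy/SciPy; it is illustrative only, not part of the repository, and the function name is made up.

import numpy as np
import scipy.linalg

def fid_from_stats(mu_real, sigma_real, mu_fake, sigma_fake):
    # FID = ||mu_f - mu_r||^2 + Tr(S_f + S_r - 2*sqrt(S_f @ S_r)), mirroring the code above.
    m = np.square(mu_fake - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)
    return float(np.real(m + np.trace(sigma_fake + sigma_real - 2 * s)))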
models/stylegan/stylegan_tf/metrics/linear_separability.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Linear Separability (LS)."""
9
+
10
+ from collections import defaultdict
11
+ import numpy as np
12
+ import sklearn.svm
13
+ import tensorflow as tf
14
+ import dnnlib.tflib as tflib
15
+
16
+ from metrics import metric_base
17
+ from training import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ classifier_urls = [
22
+ 'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl
23
+ 'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl
24
+ 'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl
25
+ 'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl
26
+ 'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl
27
+ 'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl
28
+ 'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl
29
+ 'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl
30
+ 'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl
31
+ 'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl
32
+ 'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl
33
+ 'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl
34
+ 'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl
35
+ 'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl
36
+ 'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl
37
+ 'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl
38
+ 'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl
39
+ 'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl
40
+ 'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl
41
+ 'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl
42
+ 'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl
43
+ 'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl
44
+ 'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl
45
+ 'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl
46
+ 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl
47
+ 'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl
48
+ 'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl
49
+ 'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl
50
+ 'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl
51
+ 'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl
52
+ 'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl
53
+ 'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl
54
+ 'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl
55
+ 'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl
56
+ 'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl
57
+ 'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl
58
+ 'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl
59
+ 'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl
60
+ 'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl
61
+ 'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl
62
+ ]
63
+
64
+ #----------------------------------------------------------------------------
65
+
66
+ def prob_normalize(p):
67
+ p = np.asarray(p).astype(np.float32)
68
+ assert len(p.shape) == 2
69
+ return p / np.sum(p)
70
+
71
+ def mutual_information(p):
72
+ p = prob_normalize(p)
73
+ px = np.sum(p, axis=1)
74
+ py = np.sum(p, axis=0)
75
+ result = 0.0
76
+ for x in range(p.shape[0]):
77
+ p_x = px[x]
78
+ for y in range(p.shape[1]):
79
+ p_xy = p[x][y]
80
+ p_y = py[y]
81
+ if p_xy > 0.0:
82
+ result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output
83
+ return result
84
+
85
+ def entropy(p):
86
+ p = prob_normalize(p)
87
+ result = 0.0
88
+ for x in range(p.shape[0]):
89
+ for y in range(p.shape[1]):
90
+ p_xy = p[x][y]
91
+ if p_xy > 0.0:
92
+ result -= p_xy * np.log2(p_xy)
93
+ return result
94
+
95
+ def conditional_entropy(p):
96
+ # H(Y|X) where X corresponds to axis 0, Y to axis 1
97
+ # i.e., how many bits of additional information are needed to determine where we are on axis 1 if we know where we are on axis 0?
98
+ p = prob_normalize(p)
99
+ y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)
100
+ return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up.
101
+
102
+ #----------------------------------------------------------------------------
103
+
104
+ class LS(metric_base.MetricBase):
105
+ def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):
106
+ assert num_keep <= num_samples
107
+ super().__init__(**kwargs)
108
+ self.num_samples = num_samples
109
+ self.num_keep = num_keep
110
+ self.attrib_indices = attrib_indices
111
+ self.minibatch_per_gpu = minibatch_per_gpu
112
+
113
+ def _evaluate(self, Gs, num_gpus):
114
+ minibatch_size = num_gpus * self.minibatch_per_gpu
115
+
116
+ # Construct TensorFlow graph for each GPU.
117
+ result_expr = []
118
+ for gpu_idx in range(num_gpus):
119
+ with tf.device('/gpu:%d' % gpu_idx):
120
+ Gs_clone = Gs.clone()
121
+
122
+ # Generate images.
123
+ latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
124
+ dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True)
125
+ images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True)
126
+
127
+ # Downsample to 256x256. The attribute classifiers were built for 256x256.
128
+ if images.shape[2] > 256:
129
+ factor = images.shape[2] // 256
130
+ images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
131
+ images = tf.reduce_mean(images, axis=[3, 5])
132
+
133
+ # Run classifier for each attribute.
134
+ result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
135
+ for attrib_idx in self.attrib_indices:
136
+ classifier = misc.load_pkl(classifier_urls[attrib_idx])
137
+ logits = classifier.get_output_for(images, None)
138
+ predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
139
+ result_dict[attrib_idx] = predictions
140
+ result_expr.append(result_dict)
141
+
142
+ # Sampling loop.
143
+ results = []
144
+ for _ in range(0, self.num_samples, minibatch_size):
145
+ results += tflib.run(result_expr)
146
+ results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}
147
+
148
+ # Calculate conditional entropy for each attribute.
149
+ conditional_entropies = defaultdict(list)
150
+ for attrib_idx in self.attrib_indices:
151
+ # Prune the least confident samples.
152
+ pruned_indices = list(range(self.num_samples))
153
+ pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
154
+ pruned_indices = pruned_indices[:self.num_keep]
155
+
156
+ # Fit SVM to the remaining samples.
157
+ svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
158
+ for space in ['latents', 'dlatents']:
159
+ svm_inputs = results[space][pruned_indices]
160
+ try:
161
+ svm = sklearn.svm.LinearSVC()
162
+ svm.fit(svm_inputs, svm_targets)
163
+ svm.score(svm_inputs, svm_targets)
164
+ svm_outputs = svm.predict(svm_inputs)
165
+ except:
166
+ svm_outputs = svm_targets # assume perfect prediction
167
+
168
+ # Calculate conditional entropy.
169
+ p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]
170
+ conditional_entropies[space].append(conditional_entropy(p))
171
+
172
+ # Calculate separability scores.
173
+ scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
174
+ self._report_result(scores['latents'], suffix='_z')
175
+ self._report_result(scores['dlatents'], suffix='_w')
176
+
177
+ #----------------------------------------------------------------------------
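A side note on the arithmetic behind `conditional_entropy` and the final `2**np.sum(values)` score above: the snippet below is illustrative only (not part of the repository) and recomputes H(Y|X) in bits from a made-up 2x2 agreement table between SVM predictions (rows) and classifier labels (columns), then exponentiates it as the per-attribute separability contribution.

import numpy as np

def conditional_entropy_bits(p):
    # H(Y|X) = H(X,Y) - H(X); X = SVM prediction (rows), Y = classifier label (cols).
    p = np.asarray(p, dtype=np.float64)
    p = p / p.sum()
    px = p.sum(axis=1)
    h_joint = -np.sum(p[p > 0] * np.log2(p[p > 0]))
    h_x = -np.sum(px[px > 0] * np.log2(px[px > 0]))
    return max(0.0, h_joint - h_x)

# Hypothetical joint frequency table of (svm_output, classifier_target) for one attribute:
p = [[0.45, 0.05],
     [0.05, 0.45]]
print(2 ** conditional_entropy_bits(p))  # ~1.38; perfect agreement would give 1.0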
models/stylegan/stylegan_tf/metrics/metric_base.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Common definitions for GAN metrics."""
9
+
10
+ import os
11
+ import time
12
+ import hashlib
13
+ import numpy as np
14
+ import tensorflow as tf
15
+ import dnnlib
16
+ import dnnlib.tflib as tflib
17
+
18
+ import config
19
+ from training import misc
20
+ from training import dataset
21
+
22
+ #----------------------------------------------------------------------------
23
+ # Standard metrics.
24
+
25
+ fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)
26
+ ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)
27
+ ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)
28
+ ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16)
29
+ ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16)
30
+ ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4)
31
+ dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging
32
+
33
+ #----------------------------------------------------------------------------
34
+ # Base class for metrics.
35
+
36
+ class MetricBase:
37
+ def __init__(self, name):
38
+ self.name = name
39
+ self._network_pkl = None
40
+ self._dataset_args = None
41
+ self._mirror_augment = None
42
+ self._results = []
43
+ self._eval_time = None
44
+
45
+ def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
46
+ self._network_pkl = network_pkl
47
+ self._dataset_args = dataset_args
48
+ self._mirror_augment = mirror_augment
49
+ self._results = []
50
+
51
+ if (dataset_args is None or mirror_augment is None) and run_dir is not None:
52
+ run_config = misc.parse_config_for_previous_run(run_dir)
53
+ self._dataset_args = dict(run_config['dataset'])
54
+ self._dataset_args['shuffle_mb'] = 0
55
+ self._mirror_augment = run_config['train'].get('mirror_augment', False)
56
+
57
+ time_begin = time.time()
58
+ with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
59
+ _G, _D, Gs = misc.load_pkl(self._network_pkl)
60
+ self._evaluate(Gs, num_gpus=num_gpus)
61
+ self._eval_time = time.time() - time_begin
62
+
63
+ if log_results:
64
+ result_str = self.get_result_str()
65
+ if run_dir is not None:
66
+ log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
67
+ with dnnlib.util.Logger(log, 'a'):
68
+ print(result_str)
69
+ else:
70
+ print(result_str)
71
+
72
+ def get_result_str(self):
73
+ network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
74
+ if len(network_name) > 29:
75
+ network_name = '...' + network_name[-26:]
76
+ result_str = '%-30s' % network_name
77
+ result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
78
+ for res in self._results:
79
+ result_str += ' ' + self.name + res.suffix + ' '
80
+ result_str += res.fmt % res.value
81
+ return result_str
82
+
83
+ def update_autosummaries(self):
84
+ for res in self._results:
85
+ tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)
86
+
87
+ def _evaluate(self, Gs, num_gpus):
88
+ raise NotImplementedError # to be overridden by subclasses
89
+
90
+ def _report_result(self, value, suffix='', fmt='%-10.4f'):
91
+ self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
92
+
93
+ def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
94
+ all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
95
+ all_args.update(self._dataset_args)
96
+ all_args.update(kwargs)
97
+ md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
98
+ dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
99
+ return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))
100
+
101
+ def _iterate_reals(self, minibatch_size):
102
+ dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
103
+ while True:
104
+ images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
105
+ if self._mirror_augment:
106
+ images = misc.apply_mirror_augment(images)
107
+ yield images
108
+
109
+ def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
110
+ while True:
111
+ latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
112
+ fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
113
+ images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
114
+ yield images
115
+
116
+ #----------------------------------------------------------------------------
117
+ # Group of multiple metrics.
118
+
119
+ class MetricGroup:
120
+ def __init__(self, metric_kwarg_list):
121
+ self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list]
122
+
123
+ def run(self, *args, **kwargs):
124
+ for metric in self.metrics:
125
+ metric.run(*args, **kwargs)
126
+
127
+ def get_result_str(self):
128
+ return ' '.join(metric.get_result_str() for metric in self.metrics)
129
+
130
+ def update_autosummaries(self):
131
+ for metric in self.metrics:
132
+ metric.update_autosummaries()
133
+
134
+ #----------------------------------------------------------------------------
135
+ # Dummy metric for debugging purposes.
136
+
137
+ class DummyMetric(MetricBase):
138
+ def _evaluate(self, Gs, num_gpus):
139
+ _ = Gs, num_gpus
140
+ self._report_result(0.0)
141
+
142
+ #----------------------------------------------------------------------------
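To make `_get_cache_file_for_reals` above concrete, here is an illustrative reconstruction of how the cache filename for the real-image statistics is derived. The directory, dataset name, and argument values are made up for the example and are not from the repository.

import hashlib
import os

# Hypothetical argument set; the real code merges metric args, dataset args, and kwargs.
args = dict(metric_name='fid50k', mirror_augment=True, tfrecord_dir='datasets/ffhq', num_images=50000)
md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
# Dataset name is the last path component of tfrecord_dir ('ffhq' here).
print(os.path.join('cache', '%s-%s-%s.pkl' % (md5.hexdigest(), 'fid50k', 'ffhq')))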
models/stylegan/stylegan_tf/metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Perceptual Path Length (PPL)."""
9
+
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ import dnnlib.tflib as tflib
13
+
14
+ from metrics import metric_base
15
+ from training import misc
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ # Normalize batch of vectors.
20
+ def normalize(v):
21
+ return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
22
+
23
+ # Spherical interpolation of a batch of vectors.
24
+ def slerp(a, b, t):
25
+ a = normalize(a)
26
+ b = normalize(b)
27
+ d = tf.reduce_sum(a * b, axis=-1, keepdims=True)
28
+ p = t * tf.math.acos(d)
29
+ c = normalize(b - d * a)
30
+ d = a * tf.math.cos(p) + c * tf.math.sin(p)
31
+ return normalize(d)
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ class PPL(metric_base.MetricBase):
36
+ def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs):
37
+ assert space in ['z', 'w']
38
+ assert sampling in ['full', 'end']
39
+ super().__init__(**kwargs)
40
+ self.num_samples = num_samples
41
+ self.epsilon = epsilon
42
+ self.space = space
43
+ self.sampling = sampling
44
+ self.minibatch_per_gpu = minibatch_per_gpu
45
+
46
+ def _evaluate(self, Gs, num_gpus):
47
+ minibatch_size = num_gpus * self.minibatch_per_gpu
48
+
49
+ # Construct TensorFlow graph.
50
+ distance_expr = []
51
+ for gpu_idx in range(num_gpus):
52
+ with tf.device('/gpu:%d' % gpu_idx):
53
+ Gs_clone = Gs.clone()
54
+ noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')]
55
+
56
+ # Generate random latents and interpolation t-values.
57
+ lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:])
58
+ lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0)
59
+
60
+ # Interpolate in W or Z.
61
+ if self.space == 'w':
62
+ dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True)
63
+ dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]
64
+ dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis])
65
+ dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
66
+ dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape)
67
+ else: # space == 'z'
68
+ lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]
69
+ lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])
70
+ lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon)
71
+ lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape)
72
+ dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True)
73
+
74
+ # Synthesize images.
75
+ with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch
76
+ images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False)
77
+
78
+ # Crop only the face region.
79
+ c = int(images.shape[2] // 8)
80
+ images = images[:, :, c*3 : c*7, c*2 : c*6]
81
+
82
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
83
+ if images.shape[2] > 256:
84
+ factor = images.shape[2] // 256
85
+ images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
86
+ images = tf.reduce_mean(images, axis=[3,5])
87
+
88
+ # Scale dynamic range from [-1,1] to [0,255] for VGG.
89
+ images = (images + 1) * (255 / 2)
90
+
91
+ # Evaluate perceptual distance.
92
+ img_e0, img_e1 = images[0::2], images[1::2]
93
+ distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl
94
+ distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2))
95
+
96
+ # Sampling loop.
97
+ all_distances = []
98
+ for _ in range(0, self.num_samples, minibatch_size):
99
+ all_distances += tflib.run(distance_expr)
100
+ all_distances = np.concatenate(all_distances, axis=0)
101
+
102
+ # Reject outliers.
103
+ lo = np.percentile(all_distances, 1, interpolation='lower')
104
+ hi = np.percentile(all_distances, 99, interpolation='higher')
105
+ filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances)
106
+ self._report_result(np.mean(filtered_distances))
107
+
108
+ #----------------------------------------------------------------------------
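For reference, a NumPy transcription of the `slerp` used above together with the epsilon-perturbation pattern whose perceptual distance PPL divides by eps^2. Illustrative only, not part of the repository; the added `clip` is a small guard against round-off in `arccos` that the TF version does not need.

import numpy as np

def slerp(a, b, t):
    # Spherical interpolation between latent vectors a and b at parameter t.
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    d = np.sum(a * b, axis=-1, keepdims=True)
    p = t * np.arccos(np.clip(d, -1.0, 1.0))
    c = b - d * a
    c = c / np.linalg.norm(c, axis=-1, keepdims=True)
    out = a * np.cos(p) + c * np.sin(p)
    return out / np.linalg.norm(out, axis=-1, keepdims=True)

rng = np.random.RandomState(0)
z0, z1 = rng.randn(2, 512)
t, eps = 0.3, 1e-4
e0, e1 = slerp(z0, z1, t), slerp(z0, z1, t + eps)
# A perceptual metric d(G(e0), G(e1)) would then be scaled by 1 / eps**2 to estimate path length.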
models/stylegan/stylegan_tf/pretrained_example.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Minimal script for generating an image using pre-trained StyleGAN generator."""
9
+
10
+ import os
11
+ import pickle
12
+ import numpy as np
13
+ import PIL.Image
14
+ import dnnlib
15
+ import dnnlib.tflib as tflib
16
+ import config
17
+
18
+ def main():
19
+ # Initialize TensorFlow.
20
+ tflib.init_tf()
21
+
22
+ # Load pre-trained network.
23
+ url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
24
+ with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
25
+ _G, _D, Gs = pickle.load(f)
26
+ # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
27
+ # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
28
+ # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
29
+
30
+ # Print network details.
31
+ Gs.print_layers()
32
+
33
+ # Pick latent vector.
34
+ rnd = np.random.RandomState(5)
35
+ latents = rnd.randn(1, Gs.input_shape[1])
36
+
37
+ # Generate image.
38
+ fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
39
+ images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
40
+
41
+ # Save image.
42
+ os.makedirs(config.result_dir, exist_ok=True)
43
+ png_filename = os.path.join(config.result_dir, 'example.png')
44
+ PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
45
+
46
+ if __name__ == "__main__":
47
+ main()
models/stylegan/stylegan_tf/run_metrics.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Main entry point for training StyleGAN and ProGAN networks."""
9
+
10
+ import dnnlib
11
+ from dnnlib import EasyDict
12
+ import dnnlib.tflib as tflib
13
+
14
+ import config
15
+ from metrics import metric_base
16
+ from training import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):
21
+ ctx = dnnlib.RunContext(submit_config)
22
+ tflib.init_tf()
23
+ print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl))
24
+ metric = dnnlib.util.call_func_by_name(**metric_args)
25
+ print()
26
+ metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus)
27
+ print()
28
+ ctx.close()
29
+
30
+ #----------------------------------------------------------------------------
31
+
32
+ def run_snapshot(submit_config, metric_args, run_id, snapshot):
33
+ ctx = dnnlib.RunContext(submit_config)
34
+ tflib.init_tf()
35
+ print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot))
36
+ run_dir = misc.locate_run_dir(run_id)
37
+ network_pkl = misc.locate_network_pkl(run_dir, snapshot)
38
+ metric = dnnlib.util.call_func_by_name(**metric_args)
39
+ print()
40
+ metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
41
+ print()
42
+ ctx.close()
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ def run_all_snapshots(submit_config, metric_args, run_id):
47
+ ctx = dnnlib.RunContext(submit_config)
48
+ tflib.init_tf()
49
+ print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id))
50
+ run_dir = misc.locate_run_dir(run_id)
51
+ network_pkls = misc.list_network_pkls(run_dir)
52
+ metric = dnnlib.util.call_func_by_name(**metric_args)
53
+ print()
54
+ for idx, network_pkl in enumerate(network_pkls):
55
+ ctx.update('', idx, len(network_pkls))
56
+ metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
57
+ print()
58
+ ctx.close()
59
+
60
+ #----------------------------------------------------------------------------
61
+
62
+ def main():
63
+ submit_config = dnnlib.SubmitConfig()
64
+
65
+ # Which metrics to evaluate?
66
+ metrics = []
67
+ metrics += [metric_base.fid50k]
68
+ #metrics += [metric_base.ppl_zfull]
69
+ #metrics += [metric_base.ppl_wfull]
70
+ #metrics += [metric_base.ppl_zend]
71
+ #metrics += [metric_base.ppl_wend]
72
+ #metrics += [metric_base.ls]
73
+ #metrics += [metric_base.dummy]
74
+
75
+ # Which networks to evaluate them on?
76
+ tasks = []
77
+ tasks += [EasyDict(run_func_name='run_metrics.run_pickle', network_pkl='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', dataset_args=EasyDict(tfrecord_dir='ffhq', shuffle_mb=0), mirror_augment=True)] # karras2019stylegan-ffhq-1024x1024.pkl
78
+ #tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)]
79
+ #tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)]
80
+
81
+ # How many GPUs to use?
82
+ submit_config.num_gpus = 1
83
+ #submit_config.num_gpus = 2
84
+ #submit_config.num_gpus = 4
85
+ #submit_config.num_gpus = 8
86
+
87
+ # Execute.
88
+ submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)
89
+ submit_config.run_dir_ignore += config.run_dir_ignore
90
+ for task in tasks:
91
+ for metric in metrics:
92
+ submit_config.run_desc = '%s-%s' % (task.run_func_name, metric.name)
93
+ if task.run_func_name.endswith('run_snapshot'):
94
+ submit_config.run_desc += '-%s-%s' % (task.run_id, task.snapshot)
95
+ if task.run_func_name.endswith('run_all_snapshots'):
96
+ submit_config.run_desc += '-%s' % task.run_id
97
+ submit_config.run_desc += '-%dgpu' % submit_config.num_gpus
98
+ dnnlib.submit_run(submit_config, metric_args=metric, **task)
99
+
100
+ #----------------------------------------------------------------------------
101
+
102
+ if __name__ == "__main__":
103
+ main()
104
+
105
+ #----------------------------------------------------------------------------