diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..eea938df748bd62dde8bdf5b9384d3dcd8675dad --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Erik Härkönen + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index d74693852da735718f39d1aad452481ef6936f01..48d92436c16afb605cad08292b2ebe1528434e9c 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ --- -title: Guccio AI Designer -emoji: 📉 -colorFrom: green -colorTo: purple +title: ClothingGAN +emoji: 👘 +colorFrom: indigo +colorTo: gray sdk: gradio -sdk_version: 3.1.4 +sdk_version: 2.9.4 app_file: app.py pinned: false +license: cc-by-nc-3.0 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..cefb7d0a530ae0474a0725aed487e07fb4d020e4 --- /dev/null +++ b/app.py @@ -0,0 +1,288 @@ +from ssl import ALERT_DESCRIPTION_CLOSE_NOTIFY +import nltk; nltk.download('wordnet') + +#@title Load Model +selected_model = 'lookbook' + +# Load model +from IPython.utils import io +import torch +import PIL +import numpy as np +import ipywidgets as widgets +from PIL import Image +import imageio +from models import get_instrumented_model +from decomposition import get_or_compute +from config import Config +from skimage import img_as_ubyte +import gradio as gr +import numpy as np +from ipywidgets import fixed + +# Speed up computation +torch.autograd.set_grad_enabled(False) +torch.backends.cudnn.benchmark = True + +# Specify model to use +config = Config( + model='StyleGAN2', + layer='style', + output_class=selected_model, + components=80, + use_w=True, + batch_size=5_000, # style layer quite small +) +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +inst = get_instrumented_model(config.model, config.output_class, + config.layer, torch.device(device), use_w=config.use_w) + +path_to_components = get_or_compute(config, inst) + +model = inst.model + +comps = np.load(path_to_components) +lst = comps.files +latent_dirs = [] +latent_stdevs = [] + +load_activations = False + +for item in lst: + if load_activations: + if item == 'act_comp': + for i in range(comps[item].shape[0]): + latent_dirs.append(comps[item][i]) + if item == 'act_stdev': + for i in range(comps[item].shape[0]): + latent_stdevs.append(comps[item][i]) + else: + if item == 'lat_comp': + for i in range(comps[item].shape[0]): + latent_dirs.append(comps[item][i]) + if item == 'lat_stdev': + for i in range(comps[item].shape[0]): + latent_stdevs.append(comps[item][i]) + + +#@title Define functions + + +# Taken from https://github.com/alexanderkuk/log-progress +def log_progress(sequence, every=1, size=None, name='Items'): + from ipywidgets import IntProgress, HTML, VBox + from IPython.display import display + + is_iterator = False + if size is None: + try: + size = len(sequence) + except TypeError: + is_iterator = True + if size is not None: + if every is None: + if size <= 200: + every = 1 + else: + every = int(size / 200) # every 0.5% + else: + assert every is not None, 'sequence is iterator, set every' + + if is_iterator: + progress = IntProgress(min=0, max=1, value=1) + progress.bar_style = 
'info' + else: + progress = IntProgress(min=0, max=size, value=0) + label = HTML() + box = VBox(children=[label, progress]) + display(box) + + index = 0 + try: + for index, record in enumerate(sequence, 1): + if index == 1 or index % every == 0: + if is_iterator: + label.value = '{name}: {index} / ?'.format( + name=name, + index=index + ) + else: + progress.value = index + label.value = u'{name}: {index} / {size}'.format( + name=name, + index=index, + size=size + ) + yield record + except: + progress.bar_style = 'danger' + raise + else: + progress.bar_style = 'success' + progress.value = index + label.value = "{name}: {index}".format( + name=name, + index=str(index or '?') + ) + +def name_direction(sender): + if not text.value: + print('Please name the direction before saving') + return + + if num in named_directions.values(): + target_key = list(named_directions.keys())[list(named_directions.values()).index(num)] + print(f'Direction already named: {target_key}') + print(f'Overwriting... ') + del(named_directions[target_key]) + named_directions[text.value] = [num, start_layer.value, end_layer.value] + save_direction(random_dir, text.value) + for item in named_directions: + print(item, named_directions[item]) + +def save_direction(direction, filename): + filename += ".npy" + np.save(filename, direction, allow_pickle=True, fix_imports=True) + print(f'Latent direction saved as {filename}') + +def mix_w(w1, w2, content, style): + for i in range(0,5): + w2[i] = w1[i] * (1 - content) + w2[i] * content + + for i in range(5, 16): + w2[i] = w1[i] * (1 - style) + w2[i] * style + + return w2 + +def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None): + # blockPrint() + model.truncation = truncation + if w is None: + w = model.sample_latent(1, seed=seed).detach().cpu().numpy() + w = [w]*model.get_max_latents() # one per layer + else: + w = [np.expand_dims(x, 0) for x in w] + + for l in range(start, end): + for i in range(len(directions)): + w[l] = w[l] + directions[i] * distances[i] * scale + + torch.cuda.empty_cache() + #save image and display + out = model.sample_np(w) + final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS) + + + if save is not None: + if disp == False: + print(save) + final_im.save(f'out/{seed}_{save:05}.png') + if disp: + display(final_im) + + return final_im + +def generate_mov(seed, truncation, direction_vec, scale, layers, n_frames, out_name = 'out', noise_spec = None, loop=True): + """Generates a mov moving back and forth along the chosen direction vector""" + # Example of reading a generated set of images, and storing as MP4. 
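+    # Frames sweep the offset from -10 to +10 along direction_vec (step = 20 / n_frames);
+    # when loop=True the frame list is appended in reverse so the clip plays back and forth.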
+ movieName = f'{out_name}.mp4' + offset = -10 + step = 20 / n_frames + imgs = [] + for i in log_progress(range(n_frames), name = "Generating frames"): + print(f'\r{i} / {n_frames}', end='') + w = model.sample_latent(1, seed=seed).cpu().numpy() + + model.truncation = truncation + w = [w]*model.get_max_latents() # one per layer + for l in layers: + if l <= model.get_max_latents(): + w[l] = w[l] + direction_vec * offset * scale + + #save image and display + out = model.sample_np(w) + final_im = Image.fromarray((out * 255).astype(np.uint8)) + imgs.append(out) + #increase offset + offset += step + if loop: + imgs += imgs[::-1] + with imageio.get_writer(movieName, mode='I') as writer: + for image in log_progress(list(imgs), name = "Creating animation"): + writer.append_data(img_as_ubyte(image)) + + +#@title Demo UI + + +def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer): + seed1 = int(seed1) + seed2 = int(seed2) + + scale = 1 + params = {'c0': c0, + 'c1': c1, + 'c2': c2, + 'c3': c3, + 'c4': c4, + 'c5': c5, + 'c6': c6} + + param_indexes = {'c0': 0, + 'c1': 1, + 'c2': 2, + 'c3': 3, + 'c4': 4, + 'c5': 5, + 'c6': 6} + + directions = [] + distances = [] + for k, v in params.items(): + directions.append(latent_dirs[param_indexes[k]]) + distances.append(v) + + w1 = model.sample_latent(1, seed=seed1).detach().cpu().numpy() + w1 = [w1]*model.get_max_latents() # one per layer + im1 = model.sample_np(w1) + + w2 = model.sample_latent(1, seed=seed2).detach().cpu().numpy() + w2 = [w2]*model.get_max_latents() # one per layer + im2 = model.sample_np(w2) + combined_im = np.concatenate([im1, im2], axis=1) + input_im = Image.fromarray((combined_im * 255).astype(np.uint8)) + + + mixed_w = mix_w(w1, w2, content, style) + return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False) + +truncation = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label="Truncation") +start_layer = gr.inputs.Number(default=3, label="Start Layer") +end_layer = gr.inputs.Number(default=14, label="End Layer") +seed1 = gr.inputs.Number(default=0, label="Seed 1") +seed2 = gr.inputs.Number(default=0, label="Seed 2") +content = gr.inputs.Slider(label="Structure", minimum=0, maximum=1, default=0.5) +style = gr.inputs.Slider(label="Style", minimum=0, maximum=1, default=0.5) + +slider_max_val = 20 +slider_min_val = -20 +slider_step = 1 + +c0 = gr.inputs.Slider(label="Sleeve & Size", minimum=slider_min_val, maximum=slider_max_val, default=0) +c1 = gr.inputs.Slider(label="Dress - Jacket", minimum=slider_min_val, maximum=slider_max_val, default=0) +c2 = gr.inputs.Slider(label="Female Coat", minimum=slider_min_val, maximum=slider_max_val, default=0) +c3 = gr.inputs.Slider(label="Coat", minimum=slider_min_val, maximum=slider_max_val, default=0) +c4 = gr.inputs.Slider(label="Graphics", minimum=slider_min_val, maximum=slider_max_val, default=0) +c5 = gr.inputs.Slider(label="Dark", minimum=slider_min_val, maximum=slider_max_val, default=0) +c6 = gr.inputs.Slider(label="Less Cleavage", minimum=slider_min_val, maximum=slider_max_val, default=0) + + +scale = 1 + +inputs = [seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer] +description = "Change the seed number to generate different parent design.Please give a clap/star if you find it useful :)" + +article="

Made by @AdiNarendra with 🖤. Thanks to @mfrashad for the inspiration for this.

" + +gr.Interface(generate_image, inputs, ["image", "image"], description=description, live=True,article=article,title="ClothingGAN").launch() \ No newline at end of file diff --git a/components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz b/components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz new file mode 100644 index 0000000000000000000000000000000000000000..a9d4e843c6193dd9cdc30c1ce27defc230201a89 --- /dev/null +++ b/components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc8dd611f9eba549338aaac546bc11dfda01ced79456ee0bb63387adf997bde1 +size 312337 diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5af238a0a4382504bd2af894d30331e1be33079a --- /dev/null +++ b/config.py @@ -0,0 +1,72 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import sys +import argparse +import json +from copy import deepcopy + +class Config: + def __init__(self, **kwargs): + self.from_args([]) # set all defaults + self.default_args = deepcopy(self.__dict__) + self.from_dict(kwargs) # override + + def __str__(self): + custom = {} + default = {} + + # Find non-default arguments + for k, v in self.__dict__.items(): + if k == 'default_args': + continue + + in_default = k in self.default_args + same_value = self.default_args.get(k) == v + + if in_default and same_value: + default[k] = v + else: + custom[k] = v + + config = { + 'custom': custom, + 'default': default + } + + return json.dumps(config, indent=4) + + def __repr__(self): + return self.__str__() + + def from_dict(self, dictionary): + for k, v in dictionary.items(): + setattr(self, k, v) + return self + + def from_args(self, args=sys.argv[1:]): + parser = argparse.ArgumentParser(description='GAN component analysis config') + parser.add_argument('--model', dest='model', type=str, default='StyleGAN', help='The network to analyze') # StyleGAN, DCGAN, ProGAN, BigGAN-XYZ + parser.add_argument('--layer', dest='layer', type=str, default='g_mapping', help='The layer to analyze') + parser.add_argument('--class', dest='output_class', type=str, default=None, help='Output class to generate (BigGAN: Imagenet, ProGAN: LSUN)') + parser.add_argument('--est', dest='estimator', type=str, default='ipca', help='The algorithm to use [pca, fbpca, cupca, spca, ica]') + parser.add_argument('--sparsity', type=float, default=1.0, help='Sparsity parameter of SPCA') + parser.add_argument('--video', dest='make_video', action='store_true', help='Generate output videos (MP4s)') + parser.add_argument('--batch', dest='batch_mode', action='store_true', help="Don't open windows, instead save results to file") + parser.add_argument('-b', dest='batch_size', type=int, default=None, help='Minibatch size, leave empty for automatic detection') + parser.add_argument('-c', dest='components', type=int, default=80, help='Number of components to keep') + parser.add_argument('-n', type=int, 
default=300_000, help='Number of examples to use in decomposition') + parser.add_argument('--use_w', action='store_true', help='Use W latent space (StyleGAN(2))') + parser.add_argument('--sigma', type=float, default=2.0, help='Number of stdevs to walk in visualize.py') + parser.add_argument('--inputs', type=str, default=None, help='Path to directory with named components') + parser.add_argument('--seed', type=int, default=None, help='Seed used in decomposition') + args = parser.parse_args(args) + + return self.from_dict(args.__dict__) \ No newline at end of file diff --git a/decomposition.py b/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..4819e3324707f15c33fba6f35ab6abdc66dea919 --- /dev/null +++ b/decomposition.py @@ -0,0 +1,402 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Patch for broken CTRL+C handler +# https://github.com/ContinuumIO/anaconda-issues/issues/905 +import os +os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1' + +import numpy as np +import os +from pathlib import Path +import re +import sys +import datetime +import argparse +import torch +import json +from types import SimpleNamespace +import scipy +from scipy.cluster.vq import kmeans +from tqdm import trange +from netdissect.nethook import InstrumentedModel +from config import Config +from estimators import get_estimator +from models import get_instrumented_model + +SEED_SAMPLING = 1 +SEED_RANDOM_DIRS = 2 +SEED_LINREG = 3 +SEED_VISUALIZATION = 5 + +B = 20 +n_clusters = 500 + +def get_random_dirs(components, dimensions): + gen = np.random.RandomState(seed=SEED_RANDOM_DIRS) + dirs = gen.normal(size=(components, dimensions)) + dirs /= np.sqrt(np.sum(dirs**2, axis=1, keepdims=True)) + return dirs.astype(np.float32) + +# Compute maximum batch size for given VRAM and network +def get_max_batch_size(inst, device, layer_name=None): + inst.remove_edits() + + # Reset statistics + torch.cuda.reset_max_memory_cached(device) + torch.cuda.reset_max_memory_allocated(device) + total_mem = torch.cuda.get_device_properties(device).total_memory + + B_max = 20 + + # Measure actual usage + for i in range(2, B_max, 2): + z = inst.model.sample_latent(n_samples=i) + if layer_name: + inst.model.partial_forward(z, layer_name) + else: + inst.model.forward(z) + + maxmem = torch.cuda.max_memory_allocated(device) + del z + + if maxmem > 0.5*total_mem: + print('Batch size {:d}: memory usage {:.0f}MB'.format(i, maxmem / 1e6)) + return i + + return B_max + +# Solve for directions in latent space that match PCs in activaiton space +def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config): + print('Performing least squares regression', flush=True) + + torch.manual_seed(SEED_LINREG) + np.random.seed(SEED_LINREG) + + comp = torch.from_numpy(comp_np).float().to(inst.model.device) + mean = torch.from_numpy(mean_np).float().to(inst.model.device) + stdev = torch.from_numpy(stdev_np).float().to(inst.model.device) + + n_samp = max(10_000, config.n) 
// B * B # make divisible + n_comp = comp.shape[0] + latent_dims = inst.model.get_latent_dims() + + # We're looking for M s.t. M*P*G'(Z) = Z => M*A = Z + # Z = batch of latent vectors (n_samples x latent_dims) + # G'(Z) = batch of activations at intermediate layer + # A = P*G'(Z) = projected activations (n_samples x pca_coords) + # M = linear mapping (pca_coords x latent_dims) + + # Minimization min_M ||MA - Z||_l2 rewritten as min_M.T ||A.T*M.T - Z.T||_l2 + # to match format expected by pytorch.lstsq + + # TODO: regression on pixel-space outputs? (using nonlinear optimizer) + # min_M lpips(G_full(MA), G_full(Z)) + + # Tensors to fill with data + # Dimensions other way around, so these are actually the transposes + A = np.zeros((n_samp, n_comp), dtype=np.float32) + Z = np.zeros((n_samp, latent_dims), dtype=np.float32) + + # Project tensor X onto PCs, return coordinates + def project(X, comp): + N = X.shape[0] + K = comp.shape[0] + coords = torch.bmm(comp.expand([N]+[-1]*comp.ndim), X.view(N, -1, 1)) + return coords.reshape(N, K) + + for i in trange(n_samp // B, desc='Collecting samples', ascii=True): + z = inst.model.sample_latent(B) + inst.model.partial_forward(z, config.layer) + act = inst.retained_features()[config.layer].reshape(B, -1) + + # Project onto basis + act = act - mean + coords = project(act, comp) + coords_scaled = coords / stdev + + A[i*B:(i+1)*B] = coords_scaled.detach().cpu().numpy() + Z[i*B:(i+1)*B] = z.detach().cpu().numpy().reshape(B, -1) + + # Solve least squares fit + + # gelsd = divide-and-conquer SVD; good default + # gelsy = complete orthogonal factorization; sometimes faster + # gelss = SVD; slow but less memory hungry + M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :] + + # Solution given by rows of M_t + Z_comp = M_t[:n_comp, :] + Z_mean = np.mean(Z, axis=0, keepdims=True) + + return Z_comp, Z_mean + +def regression(comp, mean, stdev, inst, config): + # Sanity check: verify orthonormality + M = np.dot(comp, comp.T) + if not np.allclose(M, np.identity(M.shape[0])): + det = np.linalg.det(M) + print(f'WARNING: Computed basis is not orthonormal (determinant={det})') + + return linreg_lstsq(comp, mean, stdev, inst, config) + +def compute(config, dump_name, instrumented_model): + global B + + timestamp = lambda : datetime.datetime.now().strftime("%d.%m %H:%M") + print(f'[{timestamp()}] Computing', dump_name.name) + + # Ensure reproducibility + torch.manual_seed(0) # also sets cuda seeds + np.random.seed(0) + + # Speed up backend + torch.backends.cudnn.benchmark = True + + has_gpu = torch.cuda.is_available() + device = torch.device('cuda' if has_gpu else 'cpu') + layer_key = config.layer + + if instrumented_model is None: + inst = get_instrumented_model(config.model, config.output_class, layer_key, device) + model = inst.model + else: + print('Reusing InstrumentedModel instance') + inst = instrumented_model + model = inst.model + inst.remove_edits() + model.set_output_class(config.output_class) + + # Regress back to w space + if config.use_w: + print('Using W latent space') + model.use_w() + + inst.retain_layer(layer_key) + model.partial_forward(model.sample_latent(1), layer_key) + sample_shape = inst.retained_features()[layer_key].shape + sample_dims = np.prod(sample_shape) + print('Feature shape:', sample_shape) + + input_shape = inst.model.get_latent_shape() + input_dims = inst.model.get_latent_dims() + + config.components = min(config.components, sample_dims) + transformer = get_estimator(config.estimator, 
config.components, config.sparsity) + + X = None + X_global_mean = None + + # Figure out batch size if not provided + B = config.batch_size or get_max_batch_size(inst, device, layer_key) + + # Divisible by B (ignored in output name) + N = config.n // B * B + + # Compute maximum batch size based on RAM + pagefile budget + target_bytes = 20 * 1_000_000_000 # GB + feat_size_bytes = sample_dims * np.dtype('float64').itemsize + N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes) + if not transformer.batch_support and N > N_limit_RAM: + print('WARNING: estimator does not support batching, ' \ + 'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N)) + + # 32-bit LAPACK gets very unhappy about huge matrices (in linalg.svd) + if config.estimator == 'ica': + lapack_max_N = np.floor_divide(np.iinfo(np.int32).max // 4, sample_dims) # 4x extra buffer + if N > lapack_max_N: + raise RuntimeError(f'Matrices too large for ICA, please use N <= {lapack_max_N}') + + print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True) + + # Must not depend on chosen batch size (reproducibility) + NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible! + + samples = None + if not transformer.batch_support: + samples = np.zeros((N + NB, sample_dims), dtype=np.float32) + + torch.manual_seed(config.seed or SEED_SAMPLING) + np.random.seed(config.seed or SEED_SAMPLING) + + # Use exactly the same latents regardless of batch size + # Store in main memory, since N might be huge (1M+) + # Run in batches, since sample_latent() might perform Z -> W mapping + n_lat = ((N + NB - 1) // B + 1) * B + latents = np.zeros((n_lat, *input_shape[1:]), dtype=np.float32) + with torch.no_grad(): + for i in trange(n_lat // B, desc='Sampling latents'): + latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy() + + # Decomposition on non-Gaussian latent space + samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W' + + canceled = False + try: + X = np.ones((NB, sample_dims), dtype=np.float32) + action = 'Fitting' if transformer.batch_support else 'Collecting' + for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True): + for mb in range(0, NB, B): + z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device) + + if samples_are_latents: + # Decomposition on latents directly (e.g. StyleGAN W) + batch = z.reshape((B, -1)) + else: + # Decomposition on intermediate layer + with torch.no_grad(): + model.partial_forward(z, layer_key) + + # Permuted to place PCA dimensions last + batch = inst.retained_features()[layer_key].reshape((B, -1)) + + space_left = min(B, NB - mb) + X[mb:mb+space_left] = batch.cpu().numpy()[:space_left] + + if transformer.batch_support: + if not transformer.fit_partial(X.reshape(-1, sample_dims)): + break + else: + samples[gi:gi+NB, :] = X.copy() + except KeyboardInterrupt: + if not transformer.batch_support: + sys.exit(1) # no progress yet + + dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}') + print(f'Saving current state to "{dump_name.name}" before exiting') + canceled = True + + if not transformer.batch_support: + X = samples # Use all samples + X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...! 
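+            # Subtract the global mean so the estimator is fit on zero-centered samples
+            # (the assert on transformer.mean_ below relies on this).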
+ X -= X_global_mean + + print(f'[{timestamp()}] Fitting whole batch') + t_start_fit = datetime.datetime.now() + + transformer.fit(X) + + print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}') + assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero' + else: + X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims)) + X = X.reshape(-1, sample_dims) + X -= X_global_mean + + X_comp, X_stdev, X_var_ratio = transformer.get_components() + + assert X_comp.shape[1] == sample_dims \ + and X_comp.shape[0] == config.components \ + and X_global_mean.shape[1] == sample_dims \ + and X_stdev.shape[0] == config.components, 'Invalid shape' + + # 'Activations' are really latents in a secondary latent space + if samples_are_latents: + Z_comp = X_comp + Z_global_mean = X_global_mean + else: + Z_comp, Z_global_mean = regression(X_comp, X_global_mean, X_stdev, inst, config) + + # Normalize + Z_comp /= np.linalg.norm(Z_comp, axis=-1, keepdims=True) + + # Random projections + # We expect these to explain much less of the variance + random_dirs = get_random_dirs(config.components, np.prod(sample_shape)) + n_rand_samples = min(5000, X.shape[0]) + X_view = X[:n_rand_samples, :].T + assert np.shares_memory(X_view, X), "Error: slice produced copy" + X_stdev_random = np.dot(random_dirs, X_view).std(axis=1) + + # Inflate back to proper shapes (for easier broadcasting) + X_comp = X_comp.reshape(-1, *sample_shape) + X_global_mean = X_global_mean.reshape(sample_shape) + Z_comp = Z_comp.reshape(-1, *input_shape) + Z_global_mean = Z_global_mean.reshape(input_shape) + + # Compute stdev in latent space if non-Gaussian + lat_stdev = np.ones_like(X_stdev) + if config.use_w: + samples = model.sample_latent(5000).reshape(5000, input_dims).detach().cpu().numpy() + coords = np.dot(Z_comp.reshape(-1, input_dims), samples.T) + lat_stdev = coords.std(axis=1) + + os.makedirs(dump_name.parent, exist_ok=True) + np.savez_compressed(dump_name, **{ + 'act_comp': X_comp.astype(np.float32), + 'act_mean': X_global_mean.astype(np.float32), + 'act_stdev': X_stdev.astype(np.float32), + 'lat_comp': Z_comp.astype(np.float32), + 'lat_mean': Z_global_mean.astype(np.float32), + 'lat_stdev': lat_stdev.astype(np.float32), + 'var_ratio': X_var_ratio.astype(np.float32), + 'random_stdevs': X_stdev_random.astype(np.float32), + }) + + if canceled: + sys.exit(1) + + # Don't shutdown if passed as param + if instrumented_model is None: + inst.close() + del inst + del model + + del X + del X_comp + del random_dirs + del batch + del samples + del latents + torch.cuda.empty_cache() + +# Return cached results or commpute if needed +# Pass existing InstrumentedModel instance to reuse it +def get_or_compute(config, model=None, submit_config=None, force_recompute=False): + if submit_config is None: + wrkdir = str(Path(__file__).parent.resolve()) + submit_config = SimpleNamespace(run_dir_root = wrkdir, run_dir = wrkdir) + + # Called directly by run.py + return _compute(submit_config, config, model, force_recompute) + +def _compute(submit_config, config, model=None, force_recompute=False): + basedir = Path(submit_config.run_dir) + outdir = basedir / 'out' + + if config.n is None: + raise RuntimeError('Must specify number of samples with -n=XXX') + + if model and not isinstance(model, InstrumentedModel): + raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"') + + if config.use_w and not 'StyleGAN' in config.model: + raise RuntimeError(f'Cannot change latent space of non-StyleGAN 
model {config.model}') + + transformer = get_estimator(config.estimator, config.components, config.sparsity) + dump_name = "{}-{}_{}_{}_n{}{}{}.npz".format( + config.model.lower(), + config.output_class.replace(' ', '_'), + config.layer.lower(), + transformer.get_param_str(), + config.n, + '_w' if config.use_w else '', + f'_seed{config.seed}' if config.seed else '' + ) + + dump_path = basedir / 'cache' / 'components' / dump_name + + if not dump_path.is_file() or force_recompute: + print('Not cached') + t_start = datetime.datetime.now() + compute(config, dump_path, model) + print('Total time:', datetime.datetime.now() - t_start) + + return dump_path \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..4e0d3ac4d0de9bbac447ba439e890a37ca61e2f7 --- /dev/null +++ b/environment.yml @@ -0,0 +1,25 @@ +name: ganspace +channels: + - defaults + - conda-forge + - pytorch +dependencies: + - python=3.7 + - pytorch::pytorch=1.3 + - pytorch::torchvision + - cudatoolkit=10.1 + - pillow=6.2 + - ffmpeg + - tqdm + - scipy + - scikit-learn + - scikit-image + - boto3 + - requests + - nltk + - pip + - pip: + - fbpca + - pyopengltk + +# conda env update -f environment.yml --prune diff --git a/estimators.py b/estimators.py new file mode 100644 index 0000000000000000000000000000000000000000..470858c8edc85a64f035fe12ceaf37182ecd497f --- /dev/null +++ b/estimators.py @@ -0,0 +1,218 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
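+
+# Each estimator below exposes get_param_str(), fit(), get_components() and a
+# batch_support flag; decomposition.py streams batches through fit_partial() when
+# batch_support is True and otherwise collects all samples and calls fit() once.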
+ +from sklearn.decomposition import FastICA, PCA, IncrementalPCA, MiniBatchSparsePCA, SparsePCA, KernelPCA +import fbpca +import numpy as np +import itertools +from types import SimpleNamespace + +# ICA +class ICAEstimator(): + def __init__(self, n_components): + self.n_components = n_components + self.maxiter = 10000 + self.whiten = True # ICA: whitening is essential, should not be skipped + self.transformer = FastICA(n_components, random_state=0, whiten=self.whiten, max_iter=self.maxiter) + self.batch_support = False + self.stdev = np.zeros((n_components,)) + self.total_var = 0.0 + + def get_param_str(self): + return "ica_c{}{}".format(self.n_components, '_w' if self.whiten else '') + + def fit(self, X): + self.transformer.fit(X) + if self.transformer.n_iter_ >= self.maxiter: + raise RuntimeError(f'FastICA did not converge (N={X.shape[0]}, it={self.maxiter})') + + # Normalize components + self.transformer.components_ /= np.sqrt(np.sum(self.transformer.components_**2, axis=-1, keepdims=True)) + + # Save variance for later + self.total_var = X.var(axis=0).sum() + + # Compute projected standard deviations + self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1) + + # Sort components based on explained variance + idx = np.argsort(self.stdev)[::-1] + self.stdev = self.stdev[idx] + self.transformer.components_[:] = self.transformer.components_[idx] + + def get_components(self): + var_ratio = self.stdev**2 / self.total_var + return self.transformer.components_, self.stdev, var_ratio # ICA outputs are not normalized + +# Incremental PCA +class IPCAEstimator(): + def __init__(self, n_components): + self.n_components = n_components + self.whiten = False + self.transformer = IncrementalPCA(n_components, whiten=self.whiten, batch_size=max(100, 2*n_components)) + self.batch_support = True + + def get_param_str(self): + return "ipca_c{}{}".format(self.n_components, '_w' if self.whiten else '') + + def fit(self, X): + self.transformer.fit(X) + + def fit_partial(self, X): + try: + self.transformer.partial_fit(X) + self.transformer.n_samples_seen_ = \ + self.transformer.n_samples_seen_.astype(np.int64) # avoid overflow + return True + except ValueError as e: + print(f'\nIPCA error:', e) + return False + + def get_components(self): + stdev = np.sqrt(self.transformer.explained_variance_) # already sorted + var_ratio = self.transformer.explained_variance_ratio_ + return self.transformer.components_, stdev, var_ratio # PCA outputs are normalized + +# Standard PCA +class PCAEstimator(): + def __init__(self, n_components): + self.n_components = n_components + self.solver = 'full' + self.transformer = PCA(n_components, svd_solver=self.solver) + self.batch_support = False + + def get_param_str(self): + return f"pca-{self.solver}_c{self.n_components}" + + def fit(self, X): + self.transformer.fit(X) + + # Save variance for later + self.total_var = X.var(axis=0).sum() + + # Compute projected standard deviations + self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1) + + # Sort components based on explained variance + idx = np.argsort(self.stdev)[::-1] + self.stdev = self.stdev[idx] + self.transformer.components_[:] = self.transformer.components_[idx] + + # Check orthogonality + dotps = [np.dot(*self.transformer.components_[[i, j]]) + for (i, j) in itertools.combinations(range(self.n_components), 2)] + if not np.allclose(dotps, 0, atol=1e-4): + print('IPCA components not orghogonal, max dot', np.abs(dotps).max()) + + self.transformer.mean_ = X.mean(axis=0, keepdims=True) + + def 
get_components(self): + var_ratio = self.stdev**2 / self.total_var + return self.transformer.components_, self.stdev, var_ratio + +# Facebook's PCA +# Good default choice: very fast and accurate. +# Very high sample counts won't fit into RAM, +# in which case IncrementalPCA must be used. +class FacebookPCAEstimator(): + def __init__(self, n_components): + self.n_components = n_components + self.transformer = SimpleNamespace() + self.batch_support = False + self.n_iter = 2 + self.l = 2*self.n_components + + def get_param_str(self): + return "fbpca_c{}_it{}_l{}".format(self.n_components, self.n_iter, self.l) + + def fit(self, X): + U, s, Va = fbpca.pca(X, k=self.n_components, n_iter=self.n_iter, raw=True, l=self.l) + self.transformer.components_ = Va + + # Save variance for later + self.total_var = X.var(axis=0).sum() + + # Compute projected standard deviations + self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1) + + # Sort components based on explained variance + idx = np.argsort(self.stdev)[::-1] + self.stdev = self.stdev[idx] + self.transformer.components_[:] = self.transformer.components_[idx] + + # Check orthogonality + dotps = [np.dot(*self.transformer.components_[[i, j]]) + for (i, j) in itertools.combinations(range(self.n_components), 2)] + if not np.allclose(dotps, 0, atol=1e-4): + print('FBPCA components not orghogonal, max dot', np.abs(dotps).max()) + + self.transformer.mean_ = X.mean(axis=0, keepdims=True) + + def get_components(self): + var_ratio = self.stdev**2 / self.total_var + return self.transformer.components_, self.stdev, var_ratio + +# Sparse PCA +# The algorithm is online along the features direction, not the samples direction +# => no partial_fit +class SPCAEstimator(): + def __init__(self, n_components, alpha=10.0): + self.n_components = n_components + self.whiten = False + self.alpha = alpha # higher alpha => sparser components + #self.transformer = MiniBatchSparsePCA(n_components, alpha=alpha, n_iter=100, + # batch_size=max(20, n_components//5), random_state=0, normalize_components=True) + self.transformer = SparsePCA(n_components, alpha=alpha, ridge_alpha=0.01, + max_iter=100, random_state=0, n_jobs=-1, normalize_components=True) # TODO: warm start using PCA result? + self.batch_support = False # maybe through memmap and HDD-stored tensor + self.stdev = np.zeros((n_components,)) + self.total_var = 0.0 + + def get_param_str(self): + return "spca_c{}_a{}{}".format(self.n_components, self.alpha, '_w' if self.whiten else '') + + def fit(self, X): + self.transformer.fit(X) + + # Save variance for later + self.total_var = X.var(axis=0).sum() + + # Compute projected standard deviations + # NB: cannot simply project with dot product! 
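+        # Sparse components are not orthonormal, so coordinates come from transform(),
+        # which solves a ridge-regularized least-squares projection (see ridge_alpha above).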
+ self.stdev = self.transformer.transform(X).std(axis=0) # X = (n_samples, n_features) + + # Sort components based on explained variance + idx = np.argsort(self.stdev)[::-1] + self.stdev = self.stdev[idx] + self.transformer.components_[:] = self.transformer.components_[idx] + + # Check orthogonality + dotps = [np.dot(*self.transformer.components_[[i, j]]) + for (i, j) in itertools.combinations(range(self.n_components), 2)] + if not np.allclose(dotps, 0, atol=1e-4): + print('SPCA components not orghogonal, max dot', np.abs(dotps).max()) + + def get_components(self): + var_ratio = self.stdev**2 / self.total_var + return self.transformer.components_, self.stdev, var_ratio # SPCA outputs are normalized + +def get_estimator(name, n_components, alpha): + if name == 'pca': + return PCAEstimator(n_components) + if name == 'ipca': + return IPCAEstimator(n_components) + elif name == 'fbpca': + return FacebookPCAEstimator(n_components) + elif name == 'ica': + return ICAEstimator(n_components) + elif name == 'spca': + return SPCAEstimator(n_components, alpha) + else: + raise RuntimeError('Unknown estimator') \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9941a7bb29d1b9a0a00f9cf90ddf2c48f1e38ed9 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +from .wrappers import * \ No newline at end of file diff --git a/models/biggan/__init__.py b/models/biggan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..583509736f3503bc277d5d2e2a69f445f7df8517 --- /dev/null +++ b/models/biggan/__init__.py @@ -0,0 +1,8 @@ +from pathlib import Path +import sys + +module_path = Path(__file__).parent / 'pytorch_biggan' +sys.path.append(str(module_path.resolve())) +from pytorch_pretrained_biggan import * +from pytorch_pretrained_biggan.model import GenBlock +from pytorch_pretrained_biggan.file_utils import http_get, s3_get \ No newline at end of file diff --git a/models/biggan/pytorch_biggan/.gitignore b/models/biggan/pytorch_biggan/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..05ddaa0a3bbca712670120686fcda8001db5ae3f --- /dev/null +++ b/models/biggan/pytorch_biggan/.gitignore @@ -0,0 +1,110 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# vscode +.vscode/ + +# models +models/ \ No newline at end of file diff --git a/models/biggan/pytorch_biggan/LICENSE b/models/biggan/pytorch_biggan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f42fd227a7d2d8baf6637ac59ca80449c2b35812 --- /dev/null +++ b/models/biggan/pytorch_biggan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Erik Härkönen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/models/biggan/pytorch_biggan/MANIFEST.in b/models/biggan/pytorch_biggan/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..1aba38f67a2211cf5b09466d7b411206cb7223bf --- /dev/null +++ b/models/biggan/pytorch_biggan/MANIFEST.in @@ -0,0 +1 @@ +include LICENSE diff --git a/models/biggan/pytorch_biggan/README.md b/models/biggan/pytorch_biggan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..deaa6c2a145a02a211ca45c59541ff88ce4da23c --- /dev/null +++ b/models/biggan/pytorch_biggan/README.md @@ -0,0 +1,227 @@ +# BigStyleGAN +This is a copy of HuggingFace's BigGAN implementation, with the addition of layerwise latent inputs. + +# PyTorch pretrained BigGAN +An op-for-op PyTorch reimplementation of DeepMind's BigGAN model with the pre-trained weights from DeepMind. + +## Introduction + +This repository contains an op-for-op PyTorch reimplementation of DeepMind's BigGAN that was released with the paper [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://openreview.net/forum?id=B1xsqj09Fm) by Andrew Brock, Jeff Donahue and Karen Simonyan. 
+ +This PyTorch implementation of BigGAN is provided with the [pretrained 128x128, 256x256 and 512x512 models by DeepMind](https://tfhub.dev/deepmind/biggan-deep-128/1). We also provide the scripts used to download and convert these models from the TensorFlow Hub models. + +This reimplementation was done from the raw computation graph of the Tensorflow version and behave similarly to the TensorFlow version (variance of the output difference of the order of 1e-5). + +This implementation currently only contains the generator as the weights of the discriminator were not released (although the structure of the discriminator is very similar to the generator so it could be added pretty easily. Tell me if you want to do a PR on that, I would be happy to help.) + +## Installation + +This repo was tested on Python 3.6 and PyTorch 1.0.1 + +PyTorch pretrained BigGAN can be installed from pip as follows: +```bash +pip install pytorch-pretrained-biggan +``` + +If you simply want to play with the GAN this should be enough. + +If you want to use the conversion scripts and the imagenet utilities, additional requirements are needed, in particular TensorFlow and NLTK. To install all the requirements please use the `full_requirements.txt` file: +```bash +git clone https://github.com/huggingface/pytorch-pretrained-BigGAN.git +cd pytorch-pretrained-BigGAN +pip install -r full_requirements.txt +``` + +## Models + +This repository provide direct and simple access to the pretrained "deep" versions of BigGAN for 128, 256 and 512 pixels resolutions as described in the [associated publication](https://openreview.net/forum?id=B1xsqj09Fm). +Here are some details on the models: + +- `BigGAN-deep-128`: a 50.4M parameters model generating 128x128 pixels images, the model dump weights 201 MB, +- `BigGAN-deep-256`: a 55.9M parameters model generating 256x256 pixels images, the model dump weights 224 MB, +- `BigGAN-deep-512`: a 56.2M parameters model generating 512x512 pixels images, the model dump weights 225 MB. + +Please refer to Appendix B of the paper for details on the architectures. + +All models comprise pre-computed batch norm statistics for 51 truncation values between 0 and 1 (see Appendix C.1 in the paper for details). + +## Usage + +Here is a quick-start example using `BigGAN` with a pre-trained model. + +See the [doc section](#doc) below for details on these classes and methods. 
+ +```python +import torch +from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample, + save_as_images, display_in_terminal) + +# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows +import logging +logging.basicConfig(level=logging.INFO) + +# Load pre-trained model tokenizer (vocabulary) +model = BigGAN.from_pretrained('biggan-deep-256') + +# Prepare a input +truncation = 0.4 +class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3) +noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3) + +# All in tensors +noise_vector = torch.from_numpy(noise_vector) +class_vector = torch.from_numpy(class_vector) + +# If you have a GPU, put everything on cuda +noise_vector = noise_vector.to('cuda') +class_vector = class_vector.to('cuda') +model.to('cuda') + +# Generate an image +with torch.no_grad(): + output = model(noise_vector, class_vector, truncation) + +# If you have a GPU put back on CPU +output = output.to('cpu') + +# If you have a sixtel compatible terminal you can display the images in the terminal +# (see https://github.com/saitoha/libsixel for details) +display_in_terminal(output) + +# Save results as png images +save_as_images(output) +``` + +![output_0](assets/output_0.png) +![output_1](assets/output_1.png) +![output_2](assets/output_2.png) + +## Doc + +### Loading DeepMind's pre-trained weights + +To load one of DeepMind's pre-trained models, instantiate a `BigGAN` model with `from_pretrained()` as: + +```python +model = BigGAN.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None) +``` + +where + +- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either: + + - the shortcut name of a Google AI's or OpenAI's pre-trained model selected in the list: + + - `biggan-deep-128`: 12-layer, 768-hidden, 12-heads, 110M parameters + - `biggan-deep-256`: 24-layer, 1024-hidden, 16-heads, 340M parameters + - `biggan-deep-512`: 12-layer, 768-hidden, 12-heads , 110M parameters + + - a path or url to a pretrained model archive containing: + + - `config.json`: a configuration file for the model, and + - `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BigGAN` (saved with the usual `torch.save()`). + + If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_biggan/model.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_biggan/`). +- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. + +### Configuration + +`BigGANConfig` is a class to store and load BigGAN configurations. It's defined in [`config.py`](./pytorch_pretrained_biggan/config.py). + +Here are some details on the attributes: + +- `output_dim`: output resolution of the GAN (128, 256 or 512) for the pre-trained models, +- `z_dim`: size of the noise vector (128 for the pre-trained models). +- `class_embed_dim`: size of the class embedding vectors (128 for the pre-trained models). +- `channel_width`: size of each channel (128 for the pre-trained models). +- `num_classes`: number of classes in the training dataset, like imagenet (1000 for the pre-trained models). +- `layers`: A list of layers definition. Each definition for a layer is a triple of [up-sample in the layer ? 
+
+### Model
+
+`BigGAN` is a PyTorch model (`torch.nn.Module`) of BigGAN defined in [`model.py`](./pytorch_pretrained_biggan/model.py). This model comprises the class embeddings (a linear layer) and the generator with a series of convolutions and conditional batch norms. The discriminator is currently not implemented since pre-trained weights have not been released for it.
+
+The inputs and outputs are **identical to the TensorFlow model inputs and outputs**.
+
+We detail them here.
+
+`BigGAN` takes as *inputs*:
+
+- `z`: a torch.FloatTensor of shape [batch_size, config.z_dim] with noise sampled from a truncated normal distribution,
+- `class_label`: a torch.FloatTensor of shape [batch_size, config.num_classes] with one-hot encoded class vectors (see the `one_hot_from_int` and `one_hot_from_names` utilities below), and
+- `truncation`: a float between 0 (excluded) and 1. The truncation of the truncated normal used for creating the noise vector. This truncation value is also used to select among the sets of pre-computed statistics (means and variances) for the batch norm layers.
+
+`BigGAN` *outputs* a tensor of shape [batch_size, 3, resolution, resolution] where resolution is 128, 256 or 512 depending on the model.
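+
+As a quick sanity check, here is a minimal sketch tying these inputs and outputs together (it assumes the `biggan-deep-256` weights; the class indices are arbitrary ImageNet classes):
+
+```python
+import torch
+from pytorch_pretrained_biggan import BigGAN, one_hot_from_int, truncated_noise_sample
+
+model = BigGAN.from_pretrained('biggan-deep-256')
+truncation = 0.4
+
+# noise: [batch_size, config.z_dim], class vector: [batch_size, config.num_classes]
+noise = torch.from_numpy(truncated_noise_sample(truncation=truncation, batch_size=2))
+classes = torch.from_numpy(one_hot_from_int([207, 8], batch_size=2))
+
+with torch.no_grad():
+    output = model(noise, classes, truncation)
+
+assert output.shape == (2, 3, 256, 256)  # [batch_size, 3, resolution, resolution]
+```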
+
+### Utilities: Images, Noise, Imagenet classes
+
+We provide a few utility methods to use the model. They are defined in [`utils.py`](./pytorch_pretrained_biggan/utils.py).
+
+Here are some details on these methods:
+
+- `truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None)`:
+
+    Create a truncated noise vector.
+    - Params:
+        - batch_size: batch size.
+        - dim_z: dimension of z.
+        - truncation: truncation value to use.
+        - seed: seed for the random generator.
+    - Output:
+        - array of shape (batch_size, dim_z)
+
+- `convert_to_images(obj)`:
+
+    Convert an output tensor from BigGAN into a list of images.
+    - Params:
+        - obj: tensor or numpy array of shape (batch_size, channels, height, width)
+    - Output:
+        - list of Pillow Images of size (height, width)
+
+- `save_as_images(obj, file_name='output')`:
+
+    Convert an output tensor from BigGAN into images and save them to disk.
+    - Params:
+        - obj: tensor or numpy array of shape (batch_size, channels, height, width)
+        - file_name: path and beginning of the filename to save to.
+          Images will be saved as `file_name_{image_number}.png`
+
+- `display_in_terminal(obj)`:
+
+    Convert and display an output tensor from BigGAN in the terminal. This function uses `libsixel` and will only work in a libsixel-compatible terminal. Please refer to https://github.com/saitoha/libsixel for more details.
+    - Params:
+        - obj: tensor or numpy array of shape (batch_size, channels, height, width)
+
+- `one_hot_from_int(int_or_list, batch_size=1)`:
+
+    Create a one-hot vector from a class index or a list of class indices.
+    - Params:
+        - int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
+        - batch_size: batch size.
+            - If int_or_list is an int, create a batch of identical classes.
+            - If int_or_list is a list, we should have `len(int_or_list) == batch_size`.
+    - Output:
+        - array of shape (batch_size, 1000)
+
+- `one_hot_from_names(class_name_or_list, batch_size=1)`:
+
+    Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...). We use NLTK's WordNet search to try to find the relevant synset of ImageNet and take the first one. If we can't find it directly, we look at the hyponyms and hypernyms of the class name.
+    - Params:
+        - class_name_or_list: string containing the name of an imagenet object, or a list of such strings (for a batch).
+    - Output:
+        - array of shape (batch_size, 1000)
+
+## Download and conversion scripts
+
+Scripts to download and convert the TensorFlow models from TensorFlow Hub are provided in [./scripts](./scripts/).
+
+The scripts can be used directly as:
+```bash
+./scripts/download_tf_hub_models.sh
+./scripts/convert_tf_hub_models.sh
+```
diff --git a/models/biggan/pytorch_biggan/full_requirements.txt b/models/biggan/pytorch_biggan/full_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f2dee70711e7f07b644d83d776cb4a2503999ff7
--- /dev/null
+++ b/models/biggan/pytorch_biggan/full_requirements.txt
@@ -0,0 +1,5 @@
+tensorflow
+tensorflow-hub
+Pillow
+nltk
+libsixel-python
\ No newline at end of file
diff --git a/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b570848421afd921fae635569c97d0f8f5b33c80
--- /dev/null
+++ b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py
@@ -0,0 +1,6 @@
+from .config import BigGANConfig
+from .model import BigGAN
+from .file_utils import PYTORCH_PRETRAINED_BIGGAN_CACHE, cached_path
+from .utils import (truncated_noise_sample, save_as_images,
+                    convert_to_images, display_in_terminal,
+                    one_hot_from_int, one_hot_from_names)
diff --git a/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..454236a4bfa0d11fda0d52e0ce9b2926f8c32d30
--- /dev/null
+++ b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py
@@ -0,0 +1,70 @@
+# coding: utf-8
+"""
+BigGAN config.
+"""
+from __future__ import (absolute_import, division, print_function, unicode_literals)
+
+import copy
+import json
+
+class BigGANConfig(object):
+    """ Configuration class to store the configuration of a `BigGAN`.
+        Defaults are for the 128x128 model.
+        layers tuple are (up-sample in the layer ?, input channels, output channels)
+    """
+    def __init__(self,
+                 output_dim=128,
+                 z_dim=128,
+                 class_embed_dim=128,
+                 channel_width=128,
+                 num_classes=1000,
+                 layers=[(False, 16, 16),
+                         (True, 16, 16),
+                         (False, 16, 16),
+                         (True, 16, 8),
+                         (False, 8, 8),
+                         (True, 8, 4),
+                         (False, 4, 4),
+                         (True, 4, 2),
+                         (False, 2, 2),
+                         (True, 2, 1)],
+                 attention_layer_position=8,
+                 eps=1e-4,
+                 n_stats=51):
+        """Constructs BigGANConfig.
""" + self.output_dim = output_dim + self.z_dim = z_dim + self.class_embed_dim = class_embed_dim + self.channel_width = channel_width + self.num_classes = num_classes + self.layers = layers + self.attention_layer_position = attention_layer_position + self.eps = eps + self.n_stats = n_stats + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BigGANConfig` from a Python dictionary of parameters.""" + config = BigGANConfig() + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BigGANConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" diff --git a/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..7ccb787dec188e9dbd9ea31288c049c1bdb30f95 --- /dev/null +++ b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py @@ -0,0 +1,312 @@ +# coding: utf-8 +""" +Convert a TF Hub model for BigGAN in a PT one. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +from itertools import chain + +import os +import argparse +import logging +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.functional import normalize + +from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME +from .config import BigGANConfig + +logger = logging.getLogger(__name__) + + +def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None): + try: + import numpy as np + import tensorflow as tf + import tensorflow_hub as hub + except ImportError: + raise ImportError("Loading a TensorFlow models in PyTorch, requires TensorFlow and TF Hub to be installed. " + "Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow. " + "And see https://github.com/tensorflow/hub for installing Hub. " + "Probably pip install tensorflow tensorflow-hub") + tf.reset_default_graph() + logger.info('Loading BigGAN module from: {}'.format(tf_model_path)) + module = hub.Module(tf_model_path) + inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k) + for k, v in module.get_input_info_dict().items()} + output = module(inputs) + + initializer = tf.global_variables_initializer() + sess = tf.Session() + stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ()) + numpy_stacks = [] + for i in stacks: + logger.info("Retrieving module_apply_default/stack_{}".format(i)) + try: + stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i) + except KeyError: + break # We have all the stats + numpy_stacks.append(sess.run(stack_var)) + + if batch_norm_stats_path is not None: + torch.save(numpy_stacks, batch_norm_stats_path) + else: + return numpy_stacks + + +def build_tf_to_pytorch_map(model, config): + """ Build a map from TF variables to PyTorch modules. 
""" + tf_to_pt_map = {} + + # Embeddings and GenZ + tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight, + 'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias, + 'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig, + 'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u}) + + # GBlock blocks + model_layer_idx = 0 + for i, (up, in_channels, out_channels) in enumerate(config.layers): + if i == config.attention_layer_position: + model_layer_idx += 1 + layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/" + layer_pnt = model.generator.layers[model_layer_idx] + for i in range(4): # Batchnorms + batch_str = layer_str + ("BatchNorm_%d/" % i if i > 0 else "BatchNorm/") + batch_pnt = getattr(layer_pnt, 'bn_%d' % i) + for name in ('offset', 'scale'): + sub_module_str = batch_str + name + "/" + sub_module_pnt = getattr(batch_pnt, name) + tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig, + sub_module_str + "u0": sub_module_pnt.weight_u}) + for i in range(4): # Convolutions + conv_str = layer_str + "conv%d/" % i + conv_pnt = getattr(layer_pnt, 'conv_%d' % i) + tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias, + conv_str + "w/ema_0.9999": conv_pnt.weight_orig, + conv_str + "u0": conv_pnt.weight_u}) + model_layer_idx += 1 + + # Attention block + layer_str = "Generator/attention/" + layer_pnt = model.generator.layers[config.attention_layer_position] + tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma}) + for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'], + ['g/', 'o_conv/', 'phi/', 'theta/']): + sub_module_str = layer_str + tf_name + sub_module_pnt = getattr(layer_pnt, pt_name) + tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig, + sub_module_str + "u0": sub_module_pnt.weight_u}) + + # final batch norm and conv to rgb + layer_str = "Generator/BatchNorm/" + layer_pnt = model.generator.bn + tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias, + layer_str + "scale/ema_0.9999": layer_pnt.weight}) + layer_str = "Generator/conv_to_rgb/" + layer_pnt = model.generator.conv_to_rgb + tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias, + layer_str + "w/ema_0.9999": layer_pnt.weight_orig, + layer_str + "u0": layer_pnt.weight_u}) + return tf_to_pt_map + + +def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None): + """ Load tf checkpoints and standing statistics in a pytorch model + """ + try: + import numpy as np + import tensorflow as tf + except ImportError: + raise ImportError("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. 
Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + # Load weights from TF model + checkpoint_path = tf_model_path + "/variables/variables" + init_vars = tf.train.list_variables(checkpoint_path) + from pprint import pprint + pprint(init_vars) + + # Extract batch norm statistics from model if needed + if batch_norm_stats_path: + stats = torch.load(batch_norm_stats_path) + else: + logger.info("Extracting batch norm stats") + stats = extract_batch_norm_stats(tf_model_path) + + # Build TF to PyTorch weights loading map + tf_to_pt_map = build_tf_to_pytorch_map(model, config) + + tf_weights = {} + for name in tf_to_pt_map.keys(): + array = tf.train.load_variable(checkpoint_path, name) + tf_weights[name] = array + # logger.info("Loading TF weight {} with shape {}".format(name, array.shape)) + + # Load parameters + with torch.no_grad(): + pt_params_pnt = set() + for name, pointer in tf_to_pt_map.items(): + array = tf_weights[name] + if pointer.dim() == 1: + if pointer.dim() < array.ndim: + array = np.squeeze(array) + elif pointer.dim() == 2: # Weights + array = np.transpose(array) + elif pointer.dim() == 4: # Convolutions + array = np.transpose(array, (3, 2, 0, 1)) + else: + raise "Wrong dimensions to adjust: " + str((pointer.shape, array.shape)) + if pointer.shape != array.shape: + raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape))) + logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape)) + pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array) + tf_weights.pop(name, None) + pt_params_pnt.add(pointer.data_ptr()) + + # Prepare SpectralNorm buffers by running one step of Spectral Norm (no need to train the model): + for module in model.modules(): + for n, buffer in module.named_buffers(): + if n == 'weight_v': + weight_mat = module.weight_orig + weight_mat = weight_mat.reshape(weight_mat.size(0), -1) + u = module.weight_u + + v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps) + buffer.data = v + pt_params_pnt.add(buffer.data_ptr()) + + u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps) + module.weight_u.data = u + pt_params_pnt.add(module.weight_u.data_ptr()) + + # Load batch norm statistics + index = 0 + for layer in model.generator.layers: + if not hasattr(layer, 'bn_0'): + continue + for i in range(4): # Batchnorms + bn_pointer = getattr(layer, 'bn_%d' % i) + pointer = bn_pointer.running_means + if pointer.shape != stats[index].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index]) + pt_params_pnt.add(pointer.data_ptr()) + + pointer = bn_pointer.running_vars + if pointer.shape != stats[index+1].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index+1]) + pt_params_pnt.add(pointer.data_ptr()) + + index += 2 + + bn_pointer = model.generator.bn + pointer = bn_pointer.running_means + if pointer.shape != stats[index].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index]) + pt_params_pnt.add(pointer.data_ptr()) + + pointer = bn_pointer.running_vars + if pointer.shape != stats[index+1].shape: + raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape)) + pointer.data = torch.from_numpy(stats[index+1]) + pt_params_pnt.add(pointer.data_ptr()) + + remaining_params = list(n for n, t in chain(model.named_parameters(), 
model.named_buffers()) \ + if t.data_ptr() not in pt_params_pnt) + + logger.info("TF Weights not copied to PyTorch model: {} -".format(', '.join(tf_weights.keys()))) + logger.info("Remanining parameters/buffers from PyTorch model: {} -".format(', '.join(remaining_params))) + + return model + + +BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1)], + attention_layer_position=8, eps=1e-4, n_stats=51) + +BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1)], + attention_layer_position=8, eps=1e-4, n_stats=51) + +BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000, + layers=[(False, 16, 16), + (True, 16, 16), + (False, 16, 16), + (True, 16, 8), + (False, 8, 8), + (True, 8, 8), + (False, 8, 8), + (True, 8, 4), + (False, 4, 4), + (True, 4, 2), + (False, 2, 2), + (True, 2, 1), + (False, 1, 1), + (True, 1, 1)], + attention_layer_position=8, eps=1e-4, n_stats=51) + + +def main(): + parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model in a PyTorch model") + parser.add_argument("--model_type", type=str, default="", required=True, + help="BigGAN model type (128, 256, 512)") + parser.add_argument("--tf_model_path", type=str, default="", required=True, + help="Path of the downloaded TF Hub model") + parser.add_argument("--pt_save_path", type=str, default="", + help="Folder to save the PyTorch model (default: Folder of the TF Hub model)") + parser.add_argument("--batch_norm_stats_path", type=str, default="", + help="Path of previously extracted batch norm statistics") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + if not args.pt_save_path: + args.pt_save_path = args.tf_model_path + + if args.model_type == "128": + config = BigGAN128 + elif args.model_type == "256": + config = BigGAN256 + elif args.model_type == "512": + config = BigGAN512 + else: + raise ValueError("model_type should be one of 128, 256 or 512") + + model = BigGAN(config) + model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path) + + model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME) + config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME) + + logger.info("Save model dump to {}".format(model_save_path)) + torch.save(model.state_dict(), model_save_path) + logger.info("Save configuration file to {}".format(config_save_path)) + with open(config_save_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + +if __name__ == "__main__": + main() diff --git a/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41624cad6d7b44c028f3ef1fb541add4956b4601 --- /dev/null +++ b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py @@ -0,0 +1,249 @@ +""" +Utilities for working with the local dataset cache. 
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import json +import logging +import os +import shutil +import tempfile +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE', + Path.home() / '.pytorch_pretrained_biggan')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. 
+ raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
+ with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding="utf-8") as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..22488abd92182a878fa1bedadfed50afbb472d3e --- /dev/null +++ b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py @@ -0,0 +1,345 @@ +# coding: utf-8 +""" BigGAN PyTorch model. + From "Large Scale GAN Training for High Fidelity Natural Image Synthesis" + By Andrew Brocky, Jeff Donahuey and Karen Simonyan. + https://openreview.net/forum?id=B1xsqj09Fm + + PyTorch version implemented from the computational graph of the TF Hub module for BigGAN. + Some part of the code are adapted from https://github.com/brain-research/self-attention-gan + + This version only comprises the generator (since the discriminator's weights are not released). + This version only comprises the "deep" version of BigGAN (see publication). 
+ + Modified by Erik Härkönen: + * Added support for per-layer latent vectors +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import os +import logging +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import BigGANConfig +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin", + 'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin", + 'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin", +} + +PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json", + 'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json", + 'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json", +} + +WEIGHTS_NAME = 'pytorch_model.bin' +CONFIG_NAME = 'config.json' + + +def snconv2d(eps=1e-12, **kwargs): + return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps) + +def snlinear(eps=1e-12, **kwargs): + return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps) + +def sn_embedding(eps=1e-12, **kwargs): + return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps) + +class SelfAttn(nn.Module): + """ Self attention Layer""" + def __init__(self, in_channels, eps=1e-12): + super(SelfAttn, self).__init__() + self.in_channels = in_channels + self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, + kernel_size=1, bias=False, eps=eps) + self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, + kernel_size=1, bias=False, eps=eps) + self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, + kernel_size=1, bias=False, eps=eps) + self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels, + kernel_size=1, bias=False, eps=eps) + self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) + self.softmax = nn.Softmax(dim=-1) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + _, ch, h, w = x.size() + # Theta path + theta = self.snconv1x1_theta(x) + theta = theta.view(-1, ch//8, h*w) + # Phi path + phi = self.snconv1x1_phi(x) + phi = self.maxpool(phi) + phi = phi.view(-1, ch//8, h*w//4) + # Attn map + attn = torch.bmm(theta.permute(0, 2, 1), phi) + attn = self.softmax(attn) + # g path + g = self.snconv1x1_g(x) + g = self.maxpool(g) + g = g.view(-1, ch//2, h*w//4) + # Attn_g - o_conv + attn_g = torch.bmm(g, attn.permute(0, 2, 1)) + attn_g = attn_g.view(-1, ch//2, h, w) + attn_g = self.snconv1x1_o_conv(attn_g) + # Out + out = x + self.gamma*attn_g + return out + + +class BigGANBatchNorm(nn.Module): + """ This is a batch norm module that can handle conditional input and can be provided with pre-computed + activation means and variances for various truncation parameters. + + We cannot just rely on torch.batch_norm since it cannot handle + batched weights (pytorch 1.0.1). We computate batch_norm our-self without updating running means and variances. + If you want to train this model you should add running means and variance computation logic. 
+ """ + def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True): + super(BigGANBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.conditional = conditional + + # We use pre-computed statistics for n_stats values of truncation between 0 and 1 + self.register_buffer('running_means', torch.zeros(n_stats, num_features)) + self.register_buffer('running_vars', torch.ones(n_stats, num_features)) + self.step_size = 1.0 / (n_stats - 1) + + if conditional: + assert condition_vector_dim is not None + self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) + self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) + else: + self.weight = torch.nn.Parameter(torch.Tensor(num_features)) + self.bias = torch.nn.Parameter(torch.Tensor(num_features)) + + def forward(self, x, truncation, condition_vector=None): + # Retreive pre-computed statistics associated to this truncation + coef, start_idx = math.modf(truncation / self.step_size) + start_idx = int(start_idx) + if coef != 0.0: # Interpolate + running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef) + running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef) + else: + running_mean = self.running_means[start_idx] + running_var = self.running_vars[start_idx] + + if self.conditional: + running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1) + bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1) + + out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias + else: + out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, + training=False, momentum=0.0, eps=self.eps) + + return out + + +class GenBlock(nn.Module): + def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False, + n_stats=51, eps=1e-12): + super(GenBlock, self).__init__() + self.up_sample = up_sample + self.drop_channels = (in_size != out_size) + middle_size = in_size // reduction_factor + + self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps) + + self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) + + self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) + + self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) + self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps) + + self.relu = nn.ReLU() + + def forward(self, x, cond_vector, truncation): + x0 = x + + x = self.bn_0(x, truncation, cond_vector) + x = self.relu(x) + x = self.conv_0(x) + + x = self.bn_1(x, truncation, cond_vector) + x = self.relu(x) + if self.up_sample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv_1(x) + + x = self.bn_2(x, truncation, cond_vector) + x = 
self.relu(x) + x = self.conv_2(x) + + x = self.bn_3(x, truncation, cond_vector) + x = self.relu(x) + x = self.conv_3(x) + + if self.drop_channels: + new_channels = x0.shape[1] // 2 + x0 = x0[:, :new_channels, ...] + if self.up_sample: + x0 = F.interpolate(x0, scale_factor=2, mode='nearest') + + out = x + x0 + return out + +class Generator(nn.Module): + def __init__(self, config): + super(Generator, self).__init__() + self.config = config + ch = config.channel_width + condition_vector_dim = config.z_dim * 2 + + self.gen_z = snlinear(in_features=condition_vector_dim, + out_features=4 * 4 * 16 * ch, eps=config.eps) + + layers = [] + for i, layer in enumerate(config.layers): + if i == config.attention_layer_position: + layers.append(SelfAttn(ch*layer[1], eps=config.eps)) + layers.append(GenBlock(ch*layer[1], + ch*layer[2], + condition_vector_dim, + up_sample=layer[0], + n_stats=config.n_stats, + eps=config.eps)) + self.layers = nn.ModuleList(layers) + + self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False) + self.relu = nn.ReLU() + self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps) + self.tanh = nn.Tanh() + + def forward(self, cond_vector, truncation): + z = self.gen_z(cond_vector[0]) + + # We use this conversion step to be able to use TF weights: + # TF convention on shape is [batch, height, width, channels] + # PT convention on shape is [batch, channels, height, width] + z = z.view(-1, 4, 4, 16 * self.config.channel_width) + z = z.permute(0, 3, 1, 2).contiguous() + + cond_idx = 1 + for i, layer in enumerate(self.layers): + if isinstance(layer, GenBlock): + z = layer(z, cond_vector[cond_idx], truncation) + cond_idx += 1 + else: + z = layer(z) + + z = self.bn(z, truncation) + z = self.relu(z) + z = self.conv_to_rgb(z) + z = z[:, :3, ...] + z = self.tanh(z) + return z + +class BigGAN(nn.Module): + """BigGAN Generator.""" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + + try: + resolved_model_file = cached_path(model_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error("Wrong model name, should be a valid path to a folder containing " + "a {} file and a {} file or a model name in {}".format( + WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys())) + raise + + logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file)) + + # Load config + config = BigGANConfig.from_json_file(resolved_config_file) + logger.info("Model config {}".format(config)) + + # Instantiate model. 
+ model = cls(config, *inputs, **kwargs) + state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None) + model.load_state_dict(state_dict, strict=False) + return model + + def __init__(self, config): + super(BigGAN, self).__init__() + self.config = config + self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False) + self.generator = Generator(config) + self.n_latents = len(config.layers) + 1 # one for gen_z + one per layer + + def forward(self, z, class_label, truncation): + assert 0 < truncation <= 1 + + if not isinstance(z, list): + z = self.n_latents*[z] + + if isinstance(class_label, list): + embed = [self.embeddings(l) for l in class_label] + else: + embed = self.n_latents*[self.embeddings(class_label)] + + assert len(z) == self.n_latents, f'Expected {self.n_latents} latents, got {len(z)}' + assert len(embed) == self.n_latents, f'Expected {self.n_latents} class vectors, got {len(class_label)}' + + cond_vectors = [torch.cat((z, e), dim=1) for (z, e) in zip(z, embed)] + z = self.generator(cond_vectors, truncation) + return z + + +if __name__ == "__main__": + import PIL + from .utils import truncated_noise_sample, save_as_images, one_hot_from_names + from .convert_tf_to_pytorch import load_tf_weights_in_biggan + + load_cache = False + cache_path = './saved_model.pt' + config = BigGANConfig() + model = BigGAN(config) + if not load_cache: + model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin') + torch.save(model.state_dict(), cache_path) + else: + model.load_state_dict(torch.load(cache_path)) + + model.eval() + + truncation = 0.4 + noise = truncated_noise_sample(batch_size=2, truncation=truncation) + label = one_hot_from_names('diver', batch_size=2) + + # Tests + # noise = np.zeros((1, 128)) + # label = [983] + + noise = torch.tensor(noise, dtype=torch.float) + label = torch.tensor(label, dtype=torch.float) + with torch.no_grad(): + outputs = model(noise, label, truncation) + print(outputs.shape) + + save_as_images(outputs) diff --git a/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9edbef3ecc9bf85092f4e670eb5fac8a3b4616 --- /dev/null +++ b/models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py @@ -0,0 +1,216 @@ +# coding: utf-8 +""" BigGAN utilities to prepare truncated noise samples and convert/save/display output images. + Also comprise ImageNet utilities to prepare one hot input vectors for ImageNet classes. + We use Wordnet so you can just input a name in a string and automatically get a corresponding + imagenet class if it exists (or a hypo/hypernym exists in imagenet). +""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +from io import BytesIO + +import numpy as np +from scipy.stats import truncnorm + +logger = logging.getLogger(__name__) + +NUM_CLASSES = 1000 + + +def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): + """ Create a truncated noise vector. + Params: + batch_size: batch size. 
+ dim_z: dimension of z + truncation: truncation value to use + seed: seed for the random generator + Output: + array of shape (batch_size, dim_z) + """ + state = None if seed is None else np.random.RandomState(seed) + values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) + return truncation * values + + +def convert_to_images(obj): + """ Convert an output tensor from BigGAN in a list of images. + Params: + obj: tensor or numpy array of shape (batch_size, channels, height, width) + Output: + list of Pillow Images of size (height, width) + """ + try: + import PIL + except ImportError: + raise ImportError("Please install Pillow to use images: pip install Pillow") + + if not isinstance(obj, np.ndarray): + obj = obj.detach().numpy() + + obj = obj.transpose((0, 2, 3, 1)) + obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255) + + img = [] + for i, out in enumerate(obj): + out_array = np.asarray(np.uint8(out), dtype=np.uint8) + img.append(PIL.Image.fromarray(out_array)) + return img + + +def save_as_images(obj, file_name='output'): + """ Convert and save an output tensor from BigGAN in a list of saved images. + Params: + obj: tensor or numpy array of shape (batch_size, channels, height, width) + file_name: path and beggingin of filename to save. + Images will be saved as `file_name_{image_number}.png` + """ + img = convert_to_images(obj) + + for i, out in enumerate(img): + current_file_name = file_name + '_%d.png' % i + logger.info("Saving image to {}".format(current_file_name)) + out.save(current_file_name, 'png') + + +def display_in_terminal(obj): + """ Convert and display an output tensor from BigGAN in the terminal. + This function use `libsixel` and will only work in a libsixel-compatible terminal. + Please refer to https://github.com/saitoha/libsixel for more details. + + Params: + obj: tensor or numpy array of shape (batch_size, channels, height, width) + file_name: path and beggingin of filename to save. + Images will be saved as `file_name_{image_number}.png` + """ + try: + import PIL + from libsixel import (sixel_output_new, sixel_dither_new, sixel_dither_initialize, + sixel_dither_set_palette, sixel_dither_set_pixelformat, + sixel_dither_get, sixel_encode, sixel_dither_unref, + sixel_output_unref, SIXEL_PIXELFORMAT_RGBA8888, + SIXEL_PIXELFORMAT_RGB888, SIXEL_PIXELFORMAT_PAL8, + SIXEL_PIXELFORMAT_G8, SIXEL_PIXELFORMAT_G1) + except ImportError: + raise ImportError("Display in Terminal requires Pillow, libsixel " + "and a libsixel compatible terminal. 
" + "Please read info at https://github.com/saitoha/libsixel " + "and install with pip install Pillow libsixel-python") + + s = BytesIO() + + images = convert_to_images(obj) + widths, heights = zip(*(i.size for i in images)) + + output_width = sum(widths) + output_height = max(heights) + + output_image = PIL.Image.new('RGB', (output_width, output_height)) + + x_offset = 0 + for im in images: + output_image.paste(im, (x_offset,0)) + x_offset += im.size[0] + + try: + data = output_image.tobytes() + except NotImplementedError: + data = output_image.tostring() + output = sixel_output_new(lambda data, s: s.write(data), s) + + try: + if output_image.mode == 'RGBA': + dither = sixel_dither_new(256) + sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGBA8888) + elif output_image.mode == 'RGB': + dither = sixel_dither_new(256) + sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGB888) + elif output_image.mode == 'P': + palette = output_image.getpalette() + dither = sixel_dither_new(256) + sixel_dither_set_palette(dither, palette) + sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_PAL8) + elif output_image.mode == 'L': + dither = sixel_dither_get(SIXEL_BUILTIN_G8) + sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G8) + elif output_image.mode == '1': + dither = sixel_dither_get(SIXEL_BUILTIN_G1) + sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G1) + else: + raise RuntimeError('unexpected output_image mode') + try: + sixel_encode(data, output_width, output_height, 1, dither, output) + print(s.getvalue().decode('ascii')) + finally: + sixel_dither_unref(dither) + finally: + sixel_output_unref(output) + + +def one_hot_from_int(int_or_list, batch_size=1): + """ Create a one-hot vector from a class index or a list of class indices. + Params: + int_or_list: int, or list of int, of the imagenet classes (between 0 and 999) + batch_size: batch size. + If int_or_list is an int create a batch of identical classes. + If int_or_list is a list, we should have `len(int_or_list) == batch_size` + Output: + array of shape (batch_size, 1000) + """ + if isinstance(int_or_list, int): + int_or_list = [int_or_list] + + if len(int_or_list) == 1 and batch_size > 1: + int_or_list = [int_or_list[0]] * batch_size + + assert batch_size == len(int_or_list) + + array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32) + for i, j in enumerate(int_or_list): + array[i, j] = 1.0 + return array + + +def one_hot_from_names(class_name_or_list, batch_size=1): + """ Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...). + We use NLTK's wordnet search to try to find the relevant synset of ImageNet and take the first one. + If we can't find it direcly, we look at the hyponyms and hypernyms of the class name. + + Params: + class_name_or_list: string containing the name of an imagenet object or a list of such strings (for a batch). 
+ Output: + array of shape (batch_size, 1000) + """ + try: + from nltk.corpus import wordnet as wn + except ImportError: + raise ImportError("You need to install nltk to use this function") + + if not isinstance(class_name_or_list, (list, tuple)): + class_name_or_list = [class_name_or_list] + else: + batch_size = max(batch_size, len(class_name_or_list)) + + classes = [] + for class_name in class_name_or_list: + class_name = class_name.replace(" ", "_") + + original_synsets = wn.synsets(class_name) + original_synsets = list(filter(lambda s: s.pos() == 'n', original_synsets)) # keep only names + if not original_synsets: + return None + + possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, original_synsets)) + if possible_synsets: + classes.append(IMAGENET[possible_synsets[0].offset()]) + else: + # try hypernyms and hyponyms + possible_synsets = sum([s.hypernyms() + s.hyponyms() for s in original_synsets], []) + possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, possible_synsets)) + if possible_synsets: + classes.append(IMAGENET[possible_synsets[0].offset()]) + + return one_hot_from_int(classes, batch_size=batch_size) + + +IMAGENET = {1440764: 0, 1443537: 1, 1484850: 2, 1491361: 3, 1494475: 4, 1496331: 5, 1498041: 6, 1514668: 7, 1514859: 8, 1518878: 9, 1530575: 10, 1531178: 11, 1532829: 12, 1534433: 13, 1537544: 14, 1558993: 15, 1560419: 16, 1580077: 17, 1582220: 18, 1592084: 19, 1601694: 20, 1608432: 21, 1614925: 22, 1616318: 23, 1622779: 24, 1629819: 25, 1630670: 26, 1631663: 27, 1632458: 28, 1632777: 29, 1641577: 30, 1644373: 31, 1644900: 32, 1664065: 33, 1665541: 34, 1667114: 35, 1667778: 36, 1669191: 37, 1675722: 38, 1677366: 39, 1682714: 40, 1685808: 41, 1687978: 42, 1688243: 43, 1689811: 44, 1692333: 45, 1693334: 46, 1694178: 47, 1695060: 48, 1697457: 49, 1698640: 50, 1704323: 51, 1728572: 52, 1728920: 53, 1729322: 54, 1729977: 55, 1734418: 56, 1735189: 57, 1737021: 58, 1739381: 59, 1740131: 60, 1742172: 61, 1744401: 62, 1748264: 63, 1749939: 64, 1751748: 65, 1753488: 66, 1755581: 67, 1756291: 68, 1768244: 69, 1770081: 70, 1770393: 71, 1773157: 72, 1773549: 73, 1773797: 74, 1774384: 75, 1774750: 76, 1775062: 77, 1776313: 78, 1784675: 79, 1795545: 80, 1796340: 81, 1797886: 82, 1798484: 83, 1806143: 84, 1806567: 85, 1807496: 86, 1817953: 87, 1818515: 88, 1819313: 89, 1820546: 90, 1824575: 91, 1828970: 92, 1829413: 93, 1833805: 94, 1843065: 95, 1843383: 96, 1847000: 97, 1855032: 98, 1855672: 99, 1860187: 100, 1871265: 101, 1872401: 102, 1873310: 103, 1877812: 104, 1882714: 105, 1883070: 106, 1910747: 107, 1914609: 108, 1917289: 109, 1924916: 110, 1930112: 111, 1943899: 112, 1944390: 113, 1945685: 114, 1950731: 115, 1955084: 116, 1968897: 117, 1978287: 118, 1978455: 119, 1980166: 120, 1981276: 121, 1983481: 122, 1984695: 123, 1985128: 124, 1986214: 125, 1990800: 126, 2002556: 127, 2002724: 128, 2006656: 129, 2007558: 130, 2009229: 131, 2009912: 132, 2011460: 133, 2012849: 134, 2013706: 135, 2017213: 136, 2018207: 137, 2018795: 138, 2025239: 139, 2027492: 140, 2028035: 141, 2033041: 142, 2037110: 143, 2051845: 144, 2056570: 145, 2058221: 146, 2066245: 147, 2071294: 148, 2074367: 149, 2077923: 150, 2085620: 151, 2085782: 152, 2085936: 153, 2086079: 154, 2086240: 155, 2086646: 156, 2086910: 157, 2087046: 158, 2087394: 159, 2088094: 160, 2088238: 161, 2088364: 162, 2088466: 163, 2088632: 164, 2089078: 165, 2089867: 166, 2089973: 167, 2090379: 168, 2090622: 169, 2090721: 170, 2091032: 171, 2091134: 172, 2091244: 173, 2091467: 174, 2091635: 175, 2091831: 176, 2092002: 
177, 2092339: 178, 2093256: 179, 2093428: 180, 2093647: 181, 2093754: 182, 2093859: 183, 2093991: 184, 2094114: 185, 2094258: 186, 2094433: 187, 2095314: 188, 2095570: 189, 2095889: 190, 2096051: 191, 2096177: 192, 2096294: 193, 2096437: 194, 2096585: 195, 2097047: 196, 2097130: 197, 2097209: 198, 2097298: 199, 2097474: 200, 2097658: 201, 2098105: 202, 2098286: 203, 2098413: 204, 2099267: 205, 2099429: 206, 2099601: 207, 2099712: 208, 2099849: 209, 2100236: 210, 2100583: 211, 2100735: 212, 2100877: 213, 2101006: 214, 2101388: 215, 2101556: 216, 2102040: 217, 2102177: 218, 2102318: 219, 2102480: 220, 2102973: 221, 2104029: 222, 2104365: 223, 2105056: 224, 2105162: 225, 2105251: 226, 2105412: 227, 2105505: 228, 2105641: 229, 2105855: 230, 2106030: 231, 2106166: 232, 2106382: 233, 2106550: 234, 2106662: 235, 2107142: 236, 2107312: 237, 2107574: 238, 2107683: 239, 2107908: 240, 2108000: 241, 2108089: 242, 2108422: 243, 2108551: 244, 2108915: 245, 2109047: 246, 2109525: 247, 2109961: 248, 2110063: 249, 2110185: 250, 2110341: 251, 2110627: 252, 2110806: 253, 2110958: 254, 2111129: 255, 2111277: 256, 2111500: 257, 2111889: 258, 2112018: 259, 2112137: 260, 2112350: 261, 2112706: 262, 2113023: 263, 2113186: 264, 2113624: 265, 2113712: 266, 2113799: 267, 2113978: 268, 2114367: 269, 2114548: 270, 2114712: 271, 2114855: 272, 2115641: 273, 2115913: 274, 2116738: 275, 2117135: 276, 2119022: 277, 2119789: 278, 2120079: 279, 2120505: 280, 2123045: 281, 2123159: 282, 2123394: 283, 2123597: 284, 2124075: 285, 2125311: 286, 2127052: 287, 2128385: 288, 2128757: 289, 2128925: 290, 2129165: 291, 2129604: 292, 2130308: 293, 2132136: 294, 2133161: 295, 2134084: 296, 2134418: 297, 2137549: 298, 2138441: 299, 2165105: 300, 2165456: 301, 2167151: 302, 2168699: 303, 2169497: 304, 2172182: 305, 2174001: 306, 2177972: 307, 2190166: 308, 2206856: 309, 2219486: 310, 2226429: 311, 2229544: 312, 2231487: 313, 2233338: 314, 2236044: 315, 2256656: 316, 2259212: 317, 2264363: 318, 2268443: 319, 2268853: 320, 2276258: 321, 2277742: 322, 2279972: 323, 2280649: 324, 2281406: 325, 2281787: 326, 2317335: 327, 2319095: 328, 2321529: 329, 2325366: 330, 2326432: 331, 2328150: 332, 2342885: 333, 2346627: 334, 2356798: 335, 2361337: 336, 2363005: 337, 2364673: 338, 2389026: 339, 2391049: 340, 2395406: 341, 2396427: 342, 2397096: 343, 2398521: 344, 2403003: 345, 2408429: 346, 2410509: 347, 2412080: 348, 2415577: 349, 2417914: 350, 2422106: 351, 2422699: 352, 2423022: 353, 2437312: 354, 2437616: 355, 2441942: 356, 2442845: 357, 2443114: 358, 2443484: 359, 2444819: 360, 2445715: 361, 2447366: 362, 2454379: 363, 2457408: 364, 2480495: 365, 2480855: 366, 2481823: 367, 2483362: 368, 2483708: 369, 2484975: 370, 2486261: 371, 2486410: 372, 2487347: 373, 2488291: 374, 2488702: 375, 2489166: 376, 2490219: 377, 2492035: 378, 2492660: 379, 2493509: 380, 2493793: 381, 2494079: 382, 2497673: 383, 2500267: 384, 2504013: 385, 2504458: 386, 2509815: 387, 2510455: 388, 2514041: 389, 2526121: 390, 2536864: 391, 2606052: 392, 2607072: 393, 2640242: 394, 2641379: 395, 2643566: 396, 2655020: 397, 2666196: 398, 2667093: 399, 2669723: 400, 2672831: 401, 2676566: 402, 2687172: 403, 2690373: 404, 2692877: 405, 2699494: 406, 2701002: 407, 2704792: 408, 2708093: 409, 2727426: 410, 2730930: 411, 2747177: 412, 2749479: 413, 2769748: 414, 2776631: 415, 2777292: 416, 2782093: 417, 2783161: 418, 2786058: 419, 2787622: 420, 2788148: 421, 2790996: 422, 2791124: 423, 2791270: 424, 2793495: 425, 2794156: 426, 2795169: 427, 2797295: 428, 2799071: 429, 2802426: 430, 
2804414: 431, 2804610: 432, 2807133: 433, 2808304: 434, 2808440: 435, 2814533: 436, 2814860: 437, 2815834: 438, 2817516: 439, 2823428: 440, 2823750: 441, 2825657: 442, 2834397: 443, 2835271: 444, 2837789: 445, 2840245: 446, 2841315: 447, 2843684: 448, 2859443: 449, 2860847: 450, 2865351: 451, 2869837: 452, 2870880: 453, 2871525: 454, 2877765: 455, 2879718: 456, 2883205: 457, 2892201: 458, 2892767: 459, 2894605: 460, 2895154: 461, 2906734: 462, 2909870: 463, 2910353: 464, 2916936: 465, 2917067: 466, 2927161: 467, 2930766: 468, 2939185: 469, 2948072: 470, 2950826: 471, 2951358: 472, 2951585: 473, 2963159: 474, 2965783: 475, 2966193: 476, 2966687: 477, 2971356: 478, 2974003: 479, 2977058: 480, 2978881: 481, 2979186: 482, 2980441: 483, 2981792: 484, 2988304: 485, 2992211: 486, 2992529: 487, 2999410: 488, 3000134: 489, 3000247: 490, 3000684: 491, 3014705: 492, 3016953: 493, 3017168: 494, 3018349: 495, 3026506: 496, 3028079: 497, 3032252: 498, 3041632: 499, 3042490: 500, 3045698: 501, 3047690: 502, 3062245: 503, 3063599: 504, 3063689: 505, 3065424: 506, 3075370: 507, 3085013: 508, 3089624: 509, 3095699: 510, 3100240: 511, 3109150: 512, 3110669: 513, 3124043: 514, 3124170: 515, 3125729: 516, 3126707: 517, 3127747: 518, 3127925: 519, 3131574: 520, 3133878: 521, 3134739: 522, 3141823: 523, 3146219: 524, 3160309: 525, 3179701: 526, 3180011: 527, 3187595: 528, 3188531: 529, 3196217: 530, 3197337: 531, 3201208: 532, 3207743: 533, 3207941: 534, 3208938: 535, 3216828: 536, 3218198: 537, 3220513: 538, 3223299: 539, 3240683: 540, 3249569: 541, 3250847: 542, 3255030: 543, 3259280: 544, 3271574: 545, 3272010: 546, 3272562: 547, 3290653: 548, 3291819: 549, 3297495: 550, 3314780: 551, 3325584: 552, 3337140: 553, 3344393: 554, 3345487: 555, 3347037: 556, 3355925: 557, 3372029: 558, 3376595: 559, 3379051: 560, 3384352: 561, 3388043: 562, 3388183: 563, 3388549: 564, 3393912: 565, 3394916: 566, 3400231: 567, 3404251: 568, 3417042: 569, 3424325: 570, 3425413: 571, 3443371: 572, 3444034: 573, 3445777: 574, 3445924: 575, 3447447: 576, 3447721: 577, 3450230: 578, 3452741: 579, 3457902: 580, 3459775: 581, 3461385: 582, 3467068: 583, 3476684: 584, 3476991: 585, 3478589: 586, 3481172: 587, 3482405: 588, 3483316: 589, 3485407: 590, 3485794: 591, 3492542: 592, 3494278: 593, 3495258: 594, 3496892: 595, 3498962: 596, 3527444: 597, 3529860: 598, 3530642: 599, 3532672: 600, 3534580: 601, 3535780: 602, 3538406: 603, 3544143: 604, 3584254: 605, 3584829: 606, 3590841: 607, 3594734: 608, 3594945: 609, 3595614: 610, 3598930: 611, 3599486: 612, 3602883: 613, 3617480: 614, 3623198: 615, 3627232: 616, 3630383: 617, 3633091: 618, 3637318: 619, 3642806: 620, 3649909: 621, 3657121: 622, 3658185: 623, 3661043: 624, 3662601: 625, 3666591: 626, 3670208: 627, 3673027: 628, 3676483: 629, 3680355: 630, 3690938: 631, 3691459: 632, 3692522: 633, 3697007: 634, 3706229: 635, 3709823: 636, 3710193: 637, 3710637: 638, 3710721: 639, 3717622: 640, 3720891: 641, 3721384: 642, 3724870: 643, 3729826: 644, 3733131: 645, 3733281: 646, 3733805: 647, 3742115: 648, 3743016: 649, 3759954: 650, 3761084: 651, 3763968: 652, 3764736: 653, 3769881: 654, 3770439: 655, 3770679: 656, 3773504: 657, 3775071: 658, 3775546: 659, 3776460: 660, 3777568: 661, 3777754: 662, 3781244: 663, 3782006: 664, 3785016: 665, 3786901: 666, 3787032: 667, 3788195: 668, 3788365: 669, 3791053: 670, 3792782: 671, 3792972: 672, 3793489: 673, 3794056: 674, 3796401: 675, 3803284: 676, 3804744: 677, 3814639: 678, 3814906: 679, 3825788: 680, 3832673: 681, 3837869: 682, 3838899: 683, 3840681: 
684, 3841143: 685, 3843555: 686, 3854065: 687, 3857828: 688, 3866082: 689, 3868242: 690, 3868863: 691, 3871628: 692, 3873416: 693, 3874293: 694, 3874599: 695, 3876231: 696, 3877472: 697, 3877845: 698, 3884397: 699, 3887697: 700, 3888257: 701, 3888605: 702, 3891251: 703, 3891332: 704, 3895866: 705, 3899768: 706, 3902125: 707, 3903868: 708, 3908618: 709, 3908714: 710, 3916031: 711, 3920288: 712, 3924679: 713, 3929660: 714, 3929855: 715, 3930313: 716, 3930630: 717, 3933933: 718, 3935335: 719, 3937543: 720, 3938244: 721, 3942813: 722, 3944341: 723, 3947888: 724, 3950228: 725, 3954731: 726, 3956157: 727, 3958227: 728, 3961711: 729, 3967562: 730, 3970156: 731, 3976467: 732, 3976657: 733, 3977966: 734, 3980874: 735, 3982430: 736, 3983396: 737, 3991062: 738, 3992509: 739, 3995372: 740, 3998194: 741, 4004767: 742, 4005630: 743, 4008634: 744, 4009552: 745, 4019541: 746, 4023962: 747, 4026417: 748, 4033901: 749, 4033995: 750, 4037443: 751, 4039381: 752, 4040759: 753, 4041544: 754, 4044716: 755, 4049303: 756, 4065272: 757, 4067472: 758, 4069434: 759, 4070727: 760, 4074963: 761, 4081281: 762, 4086273: 763, 4090263: 764, 4099969: 765, 4111531: 766, 4116512: 767, 4118538: 768, 4118776: 769, 4120489: 770, 4125021: 771, 4127249: 772, 4131690: 773, 4133789: 774, 4136333: 775, 4141076: 776, 4141327: 777, 4141975: 778, 4146614: 779, 4147183: 780, 4149813: 781, 4152593: 782, 4153751: 783, 4154565: 784, 4162706: 785, 4179913: 786, 4192698: 787, 4200800: 788, 4201297: 789, 4204238: 790, 4204347: 791, 4208210: 792, 4209133: 793, 4209239: 794, 4228054: 795, 4229816: 796, 4235860: 797, 4238763: 798, 4239074: 799, 4243546: 800, 4251144: 801, 4252077: 802, 4252225: 803, 4254120: 804, 4254680: 805, 4254777: 806, 4258138: 807, 4259630: 808, 4263257: 809, 4264628: 810, 4265275: 811, 4266014: 812, 4270147: 813, 4273569: 814, 4275548: 815, 4277352: 816, 4285008: 817, 4286575: 818, 4296562: 819, 4310018: 820, 4311004: 821, 4311174: 822, 4317175: 823, 4325704: 824, 4326547: 825, 4328186: 826, 4330267: 827, 4332243: 828, 4335435: 829, 4336792: 830, 4344873: 831, 4346328: 832, 4347754: 833, 4350905: 834, 4355338: 835, 4355933: 836, 4356056: 837, 4357314: 838, 4366367: 839, 4367480: 840, 4370456: 841, 4371430: 842, 4371774: 843, 4372370: 844, 4376876: 845, 4380533: 846, 4389033: 847, 4392985: 848, 4398044: 849, 4399382: 850, 4404412: 851, 4409515: 852, 4417672: 853, 4418357: 854, 4423845: 855, 4428191: 856, 4429376: 857, 4435653: 858, 4442312: 859, 4443257: 860, 4447861: 861, 4456115: 862, 4458633: 863, 4461696: 864, 4462240: 865, 4465501: 866, 4467665: 867, 4476259: 868, 4479046: 869, 4482393: 870, 4483307: 871, 4485082: 872, 4486054: 873, 4487081: 874, 4487394: 875, 4493381: 876, 4501370: 877, 4505470: 878, 4507155: 879, 4509417: 880, 4515003: 881, 4517823: 882, 4522168: 883, 4523525: 884, 4525038: 885, 4525305: 886, 4532106: 887, 4532670: 888, 4536866: 889, 4540053: 890, 4542943: 891, 4548280: 892, 4548362: 893, 4550184: 894, 4552348: 895, 4553703: 896, 4554684: 897, 4557648: 898, 4560804: 899, 4562935: 900, 4579145: 901, 4579432: 902, 4584207: 903, 4589890: 904, 4590129: 905, 4591157: 906, 4591713: 907, 4592741: 908, 4596742: 909, 4597913: 910, 4599235: 911, 4604644: 912, 4606251: 913, 4612504: 914, 4613696: 915, 6359193: 916, 6596364: 917, 6785654: 918, 6794110: 919, 6874185: 920, 7248320: 921, 7565083: 922, 7579787: 923, 7583066: 924, 7584110: 925, 7590611: 926, 7613480: 927, 7614500: 928, 7615774: 929, 7684084: 930, 7693725: 931, 7695742: 932, 7697313: 933, 7697537: 934, 7711569: 935, 7714571: 936, 7714990: 937, 
7715103: 938, 7716358: 939, 7716906: 940, 7717410: 941, 7717556: 942, 7718472: 943, 7718747: 944, 7720875: 945, 7730033: 946, 7734744: 947, 7742313: 948, 7745940: 949, 7747607: 950, 7749582: 951, 7753113: 952, 7753275: 953, 7753592: 954, 7754684: 955, 7760859: 956, 7768694: 957, 7802026: 958, 7831146: 959, 7836838: 960, 7860988: 961, 7871810: 962, 7873807: 963, 7875152: 964, 7880968: 965, 7892512: 966, 7920052: 967, 7930864: 968, 7932039: 969, 9193705: 970, 9229709: 971, 9246464: 972, 9256479: 973, 9288635: 974, 9332890: 975, 9399592: 976, 9421951: 977, 9428293: 978, 9468604: 979, 9472597: 980, 9835506: 981, 10148035: 982, 10565667: 983, 11879895: 984, 11939491: 985, 12057211: 986, 12144580: 987, 12267677: 988, 12620546: 989, 12768682: 990, 12985857: 991, 12998815: 992, 13037406: 993, 13040303: 994, 13044778: 995, 13052670: 996, 13054560: 997, 13133613: 998, 15075141: 999} diff --git a/models/biggan/pytorch_biggan/requirements.txt b/models/biggan/pytorch_biggan/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f37f11cc540bb1f0f777d9e08a23b9773a8db7c0 --- /dev/null +++ b/models/biggan/pytorch_biggan/requirements.txt @@ -0,0 +1,8 @@ +# PyTorch +torch>=0.4.1 +# progress bars in model download and training scripts +tqdm +# Accessing files from S3 directly. +boto3 +# Used for downloading models over HTTP +requests \ No newline at end of file diff --git a/models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh b/models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..caed81a1e9698014ac61e8baa3d98d256cb3b4dd --- /dev/null +++ b/models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh @@ -0,0 +1,21 @@ +# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +set -e +set -x + +models="128 256 512" + +mkdir -p models/model_128 +mkdir -p models/model_256 +mkdir -p models/model_512 + +# Convert TF Hub models. +for model in $models +do + pytorch_pretrained_biggan --model_type $model --tf_model_path models/model_$model --pt_save_path models/model_$model +done diff --git a/models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh b/models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..57655fbd4b77791f03d72b3dfeb3bbb89ccc2fdc --- /dev/null +++ b/models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh @@ -0,0 +1,21 @@ +# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +set -e +set -x + +models="128 256 512" + +mkdir -p models/model_128 +mkdir -p models/model_256 +mkdir -p models/model_512 + +# Download TF Hub models. +for model in $models +do + curl -L "https://tfhub.dev/deepmind/biggan-deep-$model/1?tf-hub-format=compressed" | tar -zxvC models/model_$model +done diff --git a/models/biggan/pytorch_biggan/setup.py b/models/biggan/pytorch_biggan/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a34318b6b66f1ca7b15342dea3c23eb904974d6d --- /dev/null +++ b/models/biggan/pytorch_biggan/setup.py @@ -0,0 +1,69 @@ +""" +Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py + +To create the package for pypi. + +1. 
Change the version in __init__.py and setup.py. + +2. Commit these changes with the message: "Release: VERSION" + +3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " + Push the tag to git: git push --tags origin master + +4. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + + For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. + (this will build a wheel for the python version you use to build it - make sure you use python 3.x). + + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. + +5. Check that everything looks correct by uploading the package to the pypi test server: + + twine upload dist/* -r pypitest + (pypi suggest using twine as other methods upload files via plaintext.) + + Check that you can install it in a virtualenv by running: + pip install -i https://testpypi.python.org/pypi allennlp + +6. Upload the final version to actual pypi: + twine upload dist/* -r pypi + +7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. + +""" +from io import open +from setuptools import find_packages, setup + +setup( + name="pytorch_pretrained_biggan", + version="0.1.0", + author="Thomas Wolf", + author_email="thomas@huggingface.co", + description="PyTorch version of DeepMind's BigGAN model with pre-trained models", + long_description=open("README.md", "r", encoding='utf-8').read(), + long_description_content_type="text/markdown", + keywords='BIGGAN GAN deep learning google deepmind', + license='Apache', + url="https://github.com/huggingface/pytorch-pretrained-BigGAN", + packages=find_packages(exclude=["*.tests", "*.tests.*", + "tests.*", "tests"]), + install_requires=['torch>=0.4.1', + 'numpy', + 'boto3', + 'requests', + 'tqdm'], + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + "pytorch_pretrained_biggan=pytorch_pretrained_biggan.convert_tf_to_pytorch:main", + ] + }, + classifiers=[ + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], +) diff --git a/models/stylegan/__init__.py b/models/stylegan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6edf9b7e860d2b45ed1ccf40223c6fac0b273ab7 --- /dev/null +++ b/models/stylegan/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
+ +from pathlib import Path +import sys + +#module_path = Path(__file__).parent / 'pytorch_biggan' +#sys.path.append(str(module_path.resolve())) + +from .model import StyleGAN_G, NoiseLayer \ No newline at end of file diff --git a/models/stylegan/model.py b/models/stylegan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a230961c4d1bf0bd2d1efe7972b4baa33c5d7013 --- /dev/null +++ b/models/stylegan/model.py @@ -0,0 +1,456 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict +from pathlib import Path +import requests +import pickle +import sys + +import numpy as np + +# Reimplementation of StyleGAN in PyTorch +# Source: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb + +class MyLinear(nn.Module): + """Linear layer with equalized learning rate and custom learning rate multiplier.""" + def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True): + super().__init__() + he_std = gain * input_size**(-0.5) # He init + # Equalized learning rate and custom learning rate multiplier. + if use_wscale: + init_std = 1.0 / lrmul + self.w_mul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_mul = lrmul + self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(output_size)) + self.b_mul = lrmul + else: + self.bias = None + + def forward(self, x): + bias = self.bias + if bias is not None: + bias = bias * self.b_mul + return F.linear(x, self.weight * self.w_mul, bias) + +class MyConv2d(nn.Module): + """Conv layer with equalized learning rate and custom learning rate multiplier.""" + def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True, + intermediate=None, upscale=False): + super().__init__() + if upscale: + self.upscale = Upscale2d() + else: + self.upscale = None + he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init + self.kernel_size = kernel_size + if use_wscale: + init_std = 1.0 / lrmul + self.w_mul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_mul = lrmul + self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(output_channels)) + self.b_mul = lrmul + else: + self.bias = None + self.intermediate = intermediate + + def forward(self, x): + bias = self.bias + if bias is not None: + bias = bias * self.b_mul + + have_convolution = False + if self.upscale is not None and min(x.shape[2:]) * 2 >= 128: + # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way + # this really needs to be cleaned up and go into the conv... 
+ w = self.weight * self.w_mul + w = w.permute(1, 0, 2, 3) + # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?! + w = F.pad(w, (1,1,1,1)) + w = w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] + x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2) + have_convolution = True + elif self.upscale is not None: + x = self.upscale(x) + + if not have_convolution and self.intermediate is None: + return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2) + elif not have_convolution: + x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2) + + if self.intermediate is not None: + x = self.intermediate(x) + if bias is not None: + x = x + bias.view(1, -1, 1, 1) + return x + +class NoiseLayer(nn.Module): + """adds noise. noise is per pixel (constant over channels) with per-channel weight""" + def __init__(self, channels): + super().__init__() + self.weight = nn.Parameter(torch.zeros(channels)) + self.noise = None + + def forward(self, x, noise=None): + if noise is None and self.noise is None: + noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype) + elif noise is None: + # here is a little trick: if you get all the noiselayers and set each + # modules .noise attribute, you can have pre-defined noise. + # Very useful for analysis + noise = self.noise + x = x + self.weight.view(1, -1, 1, 1) * noise + return x + +class StyleMod(nn.Module): + def __init__(self, latent_size, channels, use_wscale): + super(StyleMod, self).__init__() + self.lin = MyLinear(latent_size, + channels * 2, + gain=1.0, use_wscale=use_wscale) + + def forward(self, x, latent): + style = self.lin(latent) # style => [batch_size, n_channels*2] + shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1] + style = style.view(shape) # [batch_size, 2, n_channels, ...] + x = x * (style[:, 0] + 1.) 
+ style[:, 1] + return x + +class PixelNormLayer(nn.Module): + def __init__(self, epsilon=1e-8): + super().__init__() + self.epsilon = epsilon + def forward(self, x): + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon) + +class BlurLayer(nn.Module): + def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1): + super(BlurLayer, self).__init__() + kernel=[1, 2, 1] + kernel = torch.tensor(kernel, dtype=torch.float32) + kernel = kernel[:, None] * kernel[None, :] + kernel = kernel[None, None] + if normalize: + kernel = kernel / kernel.sum() + if flip: + kernel = kernel[:, :, ::-1, ::-1] + self.register_buffer('kernel', kernel) + self.stride = stride + + def forward(self, x): + # expand kernel channels + kernel = self.kernel.expand(x.size(1), -1, -1, -1) + x = F.conv2d( + x, + kernel, + stride=self.stride, + padding=int((self.kernel.size(2)-1)/2), + groups=x.size(1) + ) + return x + +def upscale2d(x, factor=2, gain=1): + assert x.dim() == 4 + if gain != 1: + x = x * gain + if factor != 1: + shape = x.shape + x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor) + x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3]) + return x + +class Upscale2d(nn.Module): + def __init__(self, factor=2, gain=1): + super().__init__() + assert isinstance(factor, int) and factor >= 1 + self.gain = gain + self.factor = factor + def forward(self, x): + return upscale2d(x, factor=self.factor, gain=self.gain) + +class G_mapping(nn.Sequential): + def __init__(self, nonlinearity='lrelu', use_wscale=True): + act, gain = {'relu': (torch.relu, np.sqrt(2)), + 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity] + layers = [ + ('pixel_norm', PixelNormLayer()), + ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense0_act', act), + ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense1_act', act), + ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense2_act', act), + ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense3_act', act), + ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense4_act', act), + ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense5_act', act), + ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense6_act', act), + ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ('dense7_act', act) + ] + super().__init__(OrderedDict(layers)) + + def forward(self, x): + return super().forward(x) + +class Truncation(nn.Module): + def __init__(self, avg_latent, max_layer=8, threshold=0.7): + super().__init__() + self.max_layer = max_layer + self.threshold = threshold + self.register_buffer('avg_latent', avg_latent) + def forward(self, x): + assert x.dim() == 3 + interp = torch.lerp(self.avg_latent, x, self.threshold) + do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1) + return torch.where(do_trunc, interp, x) + +class LayerEpilogue(nn.Module): + """Things to do at the end of each layer.""" + def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer): + super().__init__() + layers = [] + if use_noise: + layers.append(('noise', NoiseLayer(channels))) + layers.append(('activation', activation_layer)) + if use_pixel_norm: + 
layers.append(('pixel_norm', PixelNormLayer()))
+        if use_instance_norm:
+            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
+        self.top_epi = nn.Sequential(OrderedDict(layers))
+        if use_styles:
+            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
+        else:
+            self.style_mod = None
+    def forward(self, x, dlatents_in_slice=None):
+        x = self.top_epi(x)
+        if self.style_mod is not None:
+            x = self.style_mod(x, dlatents_in_slice)
+        else:
+            assert dlatents_in_slice is None
+        return x
+
+
+class InputBlock(nn.Module):
+    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
+        super().__init__()
+        self.const_input_layer = const_input_layer
+        self.nf = nf
+        if self.const_input_layer:
+            # called 'const' in tf
+            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
+            self.bias = nn.Parameter(torch.ones(nf))
+        else:
+            self.dense = MyLinear(dlatent_size, nf*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressive GAN
+        self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
+        self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+
+    def forward(self, dlatents_in_range):
+        batch_size = dlatents_in_range.size(0)
+        if self.const_input_layer:
+            x = self.const.expand(batch_size, -1, -1, -1)
+            x = x + self.bias.view(1, -1, 1, 1)
+        else:
+            x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
+        x = self.epi1(x, dlatents_in_range[:, 0])
+        x = self.conv(x)
+        x = self.epi2(x, dlatents_in_range[:, 1])
+        return x
+
+
+class GSynthesisBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
+        # 2**res x 2**res # res = 3..resolution_log2
+        super().__init__()
+        if blur_filter:
+            blur = BlurLayer(blur_filter)
+        else:
+            blur = None
+        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
+                                 intermediate=blur, upscale=True)
+        self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
+        self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
+
+    def forward(self, x, dlatents_in_range):
+        x = self.conv0_up(x)
+        x = self.epi1(x, dlatents_in_range[:, 0])
+        x = self.conv1(x)
+        x = self.epi2(x, dlatents_in_range[:, 1])
+        return x
+
+class G_synthesis(nn.Module):
+    def __init__(self,
+        dlatent_size        = 512,          # Disentangled latent (W) dimensionality.
+        num_channels        = 3,            # Number of output color channels.
+        resolution          = 1024,         # Output resolution.
+        fmap_base           = 8192,         # Overall multiplier for the number of feature maps.
+        fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
+        fmap_max            = 512,          # Maximum number of feature maps in any layer.
+        use_styles          = True,         # Enable style inputs?
+        const_input_layer   = True,         # First layer is a learned constant?
+        use_noise           = True,         # Enable noise inputs?
+ randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables. + nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu' + use_wscale = True, # Enable equalized learning rate? + use_pixel_norm = False, # Enable pixelwise feature vector normalization? + use_instance_norm = True, # Enable instance normalization? + dtype = torch.float32, # Data type to use for activations and outputs. + blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering. + ): + + super().__init__() + def nf(stage): + return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + self.dlatent_size = dlatent_size + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + + act, gain = {'relu': (torch.relu, np.sqrt(2)), + 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity] + num_layers = resolution_log2 * 2 - 2 + num_styles = num_layers if use_styles else 1 + torgbs = [] + blocks = [] + for res in range(2, resolution_log2 + 1): + channels = nf(res-1) + name = '{s}x{s}'.format(s=2**res) + if res == 2: + blocks.append((name, + InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale, + use_noise, use_pixel_norm, use_instance_norm, use_styles, act))) + + else: + blocks.append((name, + GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act))) + last_channels = channels + self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale) + self.blocks = nn.ModuleDict(OrderedDict(blocks)) + + def forward(self, dlatents_in): + # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size]. 
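+        # Each synthesis block below consumes two consecutive rows of dlatents_in
+        # (one w vector per style-modulated conv in the block), so a 1024x1024
+        # generator expects num_layers = 2*log2(1024) - 2 = 18 w vectors.
+        # Only the final, highest-resolution activation is mapped to RGB by self.torgb.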
+ # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype) + batch_size = dlatents_in.size(0) + for i, m in enumerate(self.blocks.values()): + if i == 0: + x = m(dlatents_in[:, 2*i:2*i+2]) + else: + x = m(x, dlatents_in[:, 2*i:2*i+2]) + rgb = self.torgb(x) + return rgb + + +class StyleGAN_G(nn.Sequential): + def __init__(self, resolution, truncation=1.0): + self.resolution = resolution + self.layers = OrderedDict([ + ('g_mapping', G_mapping()), + #('truncation', Truncation(avg_latent)), + ('g_synthesis', G_synthesis(resolution=resolution)), + ]) + super().__init__(self.layers) + + def forward(self, x, latent_is_w=False): + if isinstance(x, list): + assert len(x) == 18, 'Must provide 1 or 18 latents' + if not latent_is_w: + x = [self.layers['g_mapping'].forward(l) for l in x] + x = torch.stack(x, dim=1) + else: + if not latent_is_w: + x = self.layers['g_mapping'].forward(x) + x = x.unsqueeze(1).expand(-1, 18, -1) + + x = self.layers['g_synthesis'].forward(x) + + return x + + # From: https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/ + def load_weights(self, checkpoint): + self.load_state_dict(torch.load(checkpoint)) + + def export_from_tf(self, pickle_path): + module_path = Path(__file__).parent / 'stylegan_tf' + sys.path.append(str(module_path.resolve())) + + import dnnlib, dnnlib.tflib, pickle, torch, collections + dnnlib.tflib.init_tf() + + weights = pickle.load(open(pickle_path,'rb')) + weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k,v in w.trainables.items()]) for w in weights] + #torch.save(weights_pt, pytorch_name) + + # then on the PyTorch side run + state_G, state_D, state_Gs = weights_pt #torch.load('./karras2019stylegan-ffhq-1024x1024.pt') + def key_translate(k): + k = k.lower().split('/') + if k[0] == 'g_synthesis': + if not k[1].startswith('torgb'): + k.insert(1, 'blocks') + k = '.'.join(k) + k = (k.replace('const.const','const').replace('const.bias','bias').replace('const.stylemod','epi1.style_mod.lin') + .replace('const.noise.weight','epi1.top_epi.noise.weight') + .replace('conv.noise.weight','epi2.top_epi.noise.weight') + .replace('conv.stylemod','epi2.style_mod.lin') + .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight') + .replace('conv0_up.stylemod','epi1.style_mod.lin') + .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight') + .replace('conv1.stylemod','epi2.style_mod.lin') + .replace('torgb_lod0','torgb')) + else: + k = '.'.join(k) + return k + + def weight_translate(k, w): + k = key_translate(k) + if k.endswith('.weight'): + if w.dim() == 2: + w = w.t() + elif w.dim() == 1: + pass + else: + assert w.dim() == 4 + w = w.permute(3, 2, 0, 1) + return w + + # we delete the useless torgb filters + param_dict = {key_translate(k) : weight_translate(k, v) for k,v in state_Gs.items() if 'torgb_lod' not in key_translate(k)} + if 1: + sd_shapes = {k : v.shape for k,v in self.state_dict().items()} + param_shapes = {k : v.shape for k,v in param_dict.items() } + + for k in list(sd_shapes)+list(param_shapes): + pds = param_shapes.get(k) + sds = sd_shapes.get(k) + if pds is None: + print ("sd only", k, sds) + elif sds is None: + print ("pd only", k, pds) + elif sds != pds: + print ("mismatch!", k, pds, sds) + + self.load_state_dict(param_dict, strict=False) # needed for the blur kernels + torch.save(self.state_dict(), Path(pickle_path).with_suffix('.pt')) \ No newline at end of file diff --git a/models/stylegan/stylegan_tf/LICENSE.txt 
b/models/stylegan/stylegan_tf/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca56419327bbeeb8094330497024f109bd52b96d --- /dev/null +++ b/models/stylegan/stylegan_tf/LICENSE.txt @@ -0,0 +1,410 @@ +Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). 
To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. 
Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. 
In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. 
WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. 
Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the "Licensor." The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/models/stylegan/stylegan_tf/README.md b/models/stylegan/stylegan_tf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a86a64a60a14ccea6dc3c0a0048a243750fe98fe --- /dev/null +++ b/models/stylegan/stylegan_tf/README.md @@ -0,0 +1,232 @@ +## StyleGAN — Official TensorFlow Implementation +![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg?style=plastic) +![TensorFlow 1.10](https://img.shields.io/badge/tensorflow-1.10-green.svg?style=plastic) +![cuDNN 7.3.1](https://img.shields.io/badge/cudnn-7.3.1-green.svg?style=plastic) +![License CC BY-NC](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=plastic) + +![Teaser image](./stylegan-teaser.png) +**Picture:** *These people are not real – they were produced by our generator that allows control over different aspects of the image.* + +This repository contains the official TensorFlow implementation of the following paper: + +> **A Style-Based Generator Architecture for Generative Adversarial Networks**
+> Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
+> https://arxiv.org/abs/1812.04948 +> +> **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.* + +For business inquiries, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com)
+For press and other inquiries, please contact Hector Marinez at [hmarinez@nvidia.com](mailto:hmarinez@nvidia.com)
+ +**★★★ NEW: StyleGAN2 is available at [https://github.com/NVlabs/stylegan2](https://github.com/NVlabs/stylegan2) ★★★** + +## Resources + +Material related to our paper is available via the following links: + +- Paper: https://arxiv.org/abs/1812.04948 +- Video: https://youtu.be/kSLJriaOumA +- Code: https://github.com/NVlabs/stylegan +- FFHQ: https://github.com/NVlabs/ffhq-dataset + +Additional material can be found on Google Drive: + +| Path | Description +| :--- | :---------- +| [StyleGAN](https://drive.google.com/open?id=1uka3a1noXHAydRPRbknqwKVGODvnmUBX) | Main folder. +| ├  [stylegan-paper.pdf](https://drive.google.com/open?id=1v-HkF3Ehrpon7wVIx4r5DLcko_U_V6Lt) | High-quality version of the paper PDF. +| ├  [stylegan-video.mp4](https://drive.google.com/open?id=1uzwkZHQX_9pYg1i0d1Nbe3D9xPO8-qBf) | High-quality version of the result video. +| ├  [images](https://drive.google.com/open?id=1-l46akONUWF6LCpDoeq63H53rD7MeiTd) | Example images produced using our generator. +| │  ├  [representative-images](https://drive.google.com/open?id=1ToY5P4Vvf5_c3TyUizQ8fckFFoFtBvD8) | High-quality images to be used in articles, blog posts, etc. +| │  └  [100k-generated-images](https://drive.google.com/open?id=100DJ0QXyG89HZzB4w2Cbyf4xjNK54cQ1) | 100,000 generated images for different amounts of truncation. +| │     ├  [ffhq-1024x1024](https://drive.google.com/open?id=14lm8VRN1pr4g_KVe6_LvyDX1PObst6d4) | Generated using Flickr-Faces-HQ dataset at 1024×1024. +| │     ├  [bedrooms-256x256](https://drive.google.com/open?id=1Vxz9fksw4kgjiHrvHkX4Hze4dyThFW6t) | Generated using LSUN Bedroom dataset at 256×256. +| │     ├  [cars-512x384](https://drive.google.com/open?id=1MFCvOMdLE2_mpeLPTiDw5dxc2CRuKkzS) | Generated using LSUN Car dataset at 512×384. +| │     └  [cats-256x256](https://drive.google.com/open?id=1gq-Gj3GRFiyghTPKhp8uDMA9HV_0ZFWQ) | Generated using LSUN Cat dataset at 256×256. +| ├  [videos](https://drive.google.com/open?id=1N8pOd_Bf8v89NGUaROdbD8-ayLPgyRRo) | Example videos produced using our generator. +| │  └  [high-quality-video-clips](https://drive.google.com/open?id=1NFO7_vH0t98J13ckJYFd7kuaTkyeRJ86) | Individual segments of the result video as high-quality MP4. +| ├  [ffhq-dataset](https://drive.google.com/open?id=1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP) | Raw data for the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset). +| └  [networks](https://drive.google.com/open?id=1MASQyN5m0voPcx7-9K0r5gObhvvPups7) | Pre-trained networks as pickled instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py). +|    ├  [stylegan-ffhq-1024x1024.pkl](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ) | StyleGAN trained with Flickr-Faces-HQ dataset at 1024×1024. +|    ├  [stylegan-celebahq-1024x1024.pkl](https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf) | StyleGAN trained with CelebA-HQ dataset at 1024×1024. +|    ├  [stylegan-bedrooms-256x256.pkl](https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF) | StyleGAN trained with LSUN Bedroom dataset at 256×256. +|    ├  [stylegan-cars-512x384.pkl](https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3) | StyleGAN trained with LSUN Car dataset at 512×384. +|    ├  [stylegan-cats-256x256.pkl](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ) | StyleGAN trained with LSUN Cat dataset at 256×256. +|    └  [metrics](https://drive.google.com/open?id=1MvYdWCBuMfnoYGptRH-AgKLbPTsIQLhl) | Auxiliary networks for the quality and disentanglement metrics. 
+|       ├  [inception_v3_features.pkl](https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn) | Standard [Inception-v3](https://arxiv.org/abs/1512.00567) classifier that outputs a raw feature vector. +|       ├  [vgg16_zhang_perceptual.pkl](https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2) | Standard [LPIPS](https://arxiv.org/abs/1801.03924) metric to estimate perceptual similarity. +|       ├  [celebahq-classifier-00-male.pkl](https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX) | Binary classifier trained to detect a single attribute of CelebA-HQ. +|       └ ⋯ | Please see the file listing for remaining networks. + +## Licenses + +All material, excluding the Flickr-Faces-HQ dataset, is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made. + +For license information regarding the FFHQ dataset, please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset). + +`inception_v3_features.pkl` and `inception_v3_softmax.pkl` are derived from the pre-trained [Inception-v3](https://arxiv.org/abs/1512.00567) network by Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna. The network was originally shared under [Apache 2.0](https://github.com/tensorflow/models/blob/master/LICENSE) license on the [TensorFlow Models](https://github.com/tensorflow/models) repository. + +`vgg16.pkl` and `vgg16_zhang_perceptual.pkl` are derived from the pre-trained [VGG-16](https://arxiv.org/abs/1409.1556) network by Karen Simonyan and Andrew Zisserman. The network was originally shared under [Creative Commons BY 4.0](https://creativecommons.org/licenses/by/4.0/) license on the [Very Deep Convolutional Networks for Large-Scale Visual Recognition](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) project page. + +`vgg16_zhang_perceptual.pkl` is further derived from the pre-trained [LPIPS](https://arxiv.org/abs/1801.03924) weights by Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, and Oliver Wang. The weights were originally shared under [BSD 2-Clause "Simplified" License](https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE) on the [PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity) repository. + +## System requirements + +* Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons. +* 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer. +* TensorFlow 1.10.0 or newer with GPU support. +* One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs. +* NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer. + +## Using pre-trained networks + +A minimal example of using a pre-trained StyleGAN generator is given in [pretrained_example.py](./pretrained_example.py). When executed, the script downloads a pre-trained StyleGAN generator from Google Drive and uses it to generate an image: + +``` +> python pretrained_example.py +Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done + +Gs Params OutputShape WeightShape +--- --- --- --- +latents_in - (?, 512) - +... 
+images_out - (?, 3, 1024, 1024) -
+--- --- --- ---
+Total 26219627
+
+> ls results
+example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
+```
+
+A more advanced example is given in [generate_figures.py](./generate_figures.py). The script reproduces the figures from our paper in order to illustrate style mixing, noise inputs, and truncation:
+```
+> python generate_figures.py
+results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu
+results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6
+results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG
+results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_
+results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v
+results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr
+results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke
+results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
+```
+
+The pre-trained networks are stored as standard pickle files on Google Drive:
+
+```
+# Load pre-trained network.
+url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
+with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
+    _G, _D, Gs = pickle.load(f)
+    # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
+    # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
+    # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
+```
+
+The above code downloads the file and unpickles it to yield 3 instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py). To generate images, you will typically want to use `Gs` – the other two networks are provided for completeness. In order for `pickle.load()` to work, you will need to have the `dnnlib` source directory in your PYTHONPATH and a `tf.Session` set as default. The session can be initialized by calling `dnnlib.tflib.init_tf()`.
+
+There are three ways to use the pre-trained generator:
+
+1. Use `Gs.run()` for immediate-mode operation where the inputs and outputs are numpy arrays:
+   ```
+   # Pick latent vector.
+   rnd = np.random.RandomState(5)
+   latents = rnd.randn(1, Gs.input_shape[1])
+
+   # Generate image.
+   fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
+   images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
+   ```
+   The first argument is a batch of latent vectors of shape `[num, 512]`. The second argument is reserved for class labels (not used by StyleGAN). The remaining keyword arguments are optional and can be used to further modify the operation (see below). The output is a batch of images, whose format is dictated by the `output_transform` argument.
+
Use `Gs.get_output_for()` to incorporate the generator as a part of a larger TensorFlow expression: + ``` + latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) + images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) + images = tflib.convert_images_to_uint8(images) + result_expr.append(inception_clone.get_output_for(images)) + ``` + The above code is from [metrics/frechet_inception_distance.py](./metrics/frechet_inception_distance.py). It generates a batch of random images and feeds them directly to the [Inception-v3](https://arxiv.org/abs/1512.00567) network without having to convert the data to numpy arrays in between. + +3. Look up `Gs.components.mapping` and `Gs.components.synthesis` to access individual sub-networks of the generator. Similar to `Gs`, the sub-networks are represented as independent instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py): + ``` + src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) + src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] + src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs) + ``` + The above code is from [generate_figures.py](./generate_figures.py). It first transforms a batch of latent vectors into the intermediate *W* space using the mapping network and then turns these vectors into a batch of images using the synthesis network. The `dlatents` array stores a separate copy of the same *w* vector for each layer of the synthesis network to facilitate style mixing. + +The exact details of the generator are defined in [training/networks_stylegan.py](./training/networks_stylegan.py) (see `G_style`, `G_mapping`, and `G_synthesis`). The following keyword arguments can be specified to modify the behavior when calling `run()` and `get_output_for()`: + +* `truncation_psi` and `truncation_cutoff` control the truncation trick that is performed by default when using `Gs` (ψ=0.7, cutoff=8). It can be disabled by setting `truncation_psi=1` or `is_validation=True`, and the image quality can be further improved at the cost of variation by setting e.g. `truncation_psi=0.5`. Note that truncation is always disabled when using the sub-networks directly. The average *w* needed to manually perform the truncation trick can be looked up using `Gs.get_var('dlatent_avg')`. + +* `randomize_noise` determines whether to re-randomize the noise inputs for each generated image (`True`, default) or whether to use specific noise values for the entire minibatch (`False`). The specific values can be accessed via the `tf.Variable` instances that are found using `[var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]`. + +* When using the mapping network directly, you can specify `dlatent_broadcast=None` to disable the automatic duplication of `dlatents` over the layers of the synthesis network. + +* Runtime performance can be fine-tuned via `structure='fixed'` and `dtype='float16'`. The former disables support for progressive growing, which is not needed for a fully-trained generator, and the latter performs all computation using half-precision floating point arithmetic.
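To make these keyword arguments concrete, here is a minimal sketch (an editorial addition, not part of the original repository) that performs the truncation trick by hand before calling the sub-networks, where truncation is otherwise disabled. It assumes `Gs`, `np`, and `tflib` are set up exactly as in the snippets above and mirrors the default ψ=0.7 and cutoff=8 mentioned in the first bullet:

```
# Hypothetical sketch: manual truncation trick via the sub-networks.
latents = np.random.RandomState(42).randn(4, Gs.input_shape[1])   # [minibatch, 512]
dlatents = Gs.components.mapping.run(latents, None)               # [minibatch, layer, 512]
dlatent_avg = Gs.get_var('dlatent_avg')                           # [512]

# Interpolate the first `cutoff` layers towards the average w by a factor of psi.
psi, cutoff = 0.7, 8
dlatents[:, :cutoff] = dlatent_avg + (dlatents[:, :cutoff] - dlatent_avg) * psi

# Synthesize uint8 images in NHWC order, as in the immediate-mode example above.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.components.synthesis.run(dlatents, randomize_noise=False, output_transform=fmt)
```

Setting `randomize_noise=False` here keeps the noise inputs fixed across calls, as described in the second bullet.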
+## Preparing datasets for training + +The training and evaluation scripts operate on datasets stored as multi-resolution TFRecords. Each dataset is represented by a directory containing the same image data in several resolutions to enable efficient streaming. There is a separate *.tfrecords file for each resolution, and if the dataset contains labels, they are stored in a separate file as well. By default, the scripts expect to find the datasets at `datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords`. The directory can be changed by editing [config.py](./config.py): + +``` +result_dir = 'results' +data_dir = 'datasets' +cache_dir = 'cache' +``` + +To obtain the FFHQ dataset (`datasets/ffhq`), please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset). + +To obtain the CelebA-HQ dataset (`datasets/celebahq`), please refer to the [Progressive GAN repository](https://github.com/tkarras/progressive_growing_of_gans). + +To obtain other datasets, including LSUN, please consult their corresponding project pages. The datasets can be converted to multi-resolution TFRecords using the provided [dataset_tool.py](./dataset_tool.py): + +``` +> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256 +> python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384 +> python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256 +> python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10 +> python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images +``` + +## Training networks + +Once the datasets are set up, you can train your own StyleGAN networks as follows: + +1. Edit [train.py](./train.py) to specify the dataset and training configuration by uncommenting or editing specific lines. +2. Run the training script with `python train.py`. +3. The results are written to a newly created directory `results/<ID>-<DESC>`. +4. The training may take several days (or weeks) to complete, depending on the configuration. + +By default, `train.py` is configured to train the highest-quality StyleGAN (configuration F in Table 1) for the FFHQ dataset at 1024×1024 resolution using 8 GPUs. Please note that we have used 8 GPUs in all of our experiments. Training with fewer GPUs may not produce identical results – if you wish to compare against our technique, we strongly recommend using the same number of GPUs. + +Expected training times for the default configuration using Tesla V100 GPUs: + +| GPUs | 1024×1024 | 512×512 | 256×256 | +| :--- | :-------------- | :------------ | :------------ | +| 1 | 41 days 4 hours | 24 days 21 hours | 14 days 22 hours | +| 2 | 21 days 22 hours | 13 days 7 hours | 9 days 5 hours | +| 4 | 11 days 8 hours | 7 days 0 hours | 4 days 21 hours | +| 8 | 6 days 14 hours | 4 days 10 hours | 3 days 8 hours | + +## Evaluating quality and disentanglement + +The quality and disentanglement metrics used in our paper can be evaluated using [run_metrics.py](./run_metrics.py). By default, the script will evaluate the Fréchet Inception Distance (`fid50k`) for the pre-trained FFHQ generator and write the results into a newly created directory under `results`. The exact behavior can be changed by uncommenting or editing specific lines in [run_metrics.py](./run_metrics.py). + +Expected evaluation time and results for the pre-trained FFHQ generator using one Tesla V100 GPU: + +| Metric | Time | Result | Description +| :----- | :--- | :----- | :---------- +| fid50k | 16 min | 4.4159 | Fréchet Inception Distance using 50,000 images. +| ppl_zfull | 55 min | 664.8854 | Perceptual Path Length for full paths in *Z*. +| ppl_wfull | 55 min | 233.3059 | Perceptual Path Length for full paths in *W*. 
+| ppl_zend | 55 min | 666.1057 | Perceptual Path Length for path endpoints in *Z*. +| ppl_wend | 55 min | 197.2266 | Perceptual Path Length for path endpoints in *W*. +| ls | 10 hours | z: 165.0106<br>w: 3.7447 | Linear Separability in *Z* and *W*. + +Please note that the exact results may vary from run to run due to the non-deterministic nature of TensorFlow. + +## Acknowledgements + +We thank Jaakko Lehtinen, David Luebke, and Tuomas Kynkäänniemi for in-depth discussions and helpful comments; Janne Hellsten, Tero Kuosmanen, and Pekka Jänis for compute infrastructure and help with the code release. diff --git a/models/stylegan/stylegan_tf/config.py b/models/stylegan/stylegan_tf/config.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf45253e888806dc58d8dfc994d2dad96527172 --- /dev/null +++ b/models/stylegan/stylegan_tf/config.py @@ -0,0 +1,18 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Global configuration.""" + +#---------------------------------------------------------------------------- +# Paths. + +result_dir = 'results' +data_dir = 'datasets' +cache_dir = 'cache' +run_dir_ignore = ['results', 'datasets', 'cache'] + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/dataset_tool.py b/models/stylegan/stylegan_tf/dataset_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddfe448e2ccaa30e04ad4b49761d406846c962f --- /dev/null +++ b/models/stylegan/stylegan_tf/dataset_tool.py @@ -0,0 +1,645 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.""" + +# pylint: disable=too-many-lines +import os +import sys +import glob +import argparse +import threading +import six.moves.queue as Queue # pylint: disable=import-error +import traceback +import numpy as np +import tensorflow as tf +import PIL.Image +import dnnlib.tflib as tflib + +from training import dataset + +#---------------------------------------------------------------------------- + +def error(msg): + print('Error: ' + msg) + exit(1) + +#---------------------------------------------------------------------------- + +class TFRecordExporter: + def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10): + self.tfrecord_dir = tfrecord_dir + self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir)) + self.expected_images = expected_images + self.cur_images = 0 + self.shape = None + self.resolution_log2 = None + self.tfr_writers = [] + self.print_progress = print_progress + self.progress_interval = progress_interval + + if self.print_progress: + print('Creating dataset "%s"' % tfrecord_dir) + if not os.path.isdir(self.tfrecord_dir): + os.makedirs(self.tfrecord_dir) + assert os.path.isdir(self.tfrecord_dir) + + def close(self): + if self.print_progress: + print('%-40s\r' % 'Flushing data...', end='', flush=True) + for tfr_writer in self.tfr_writers: + tfr_writer.close() + self.tfr_writers = [] + if self.print_progress: + print('%-40s\r' % '', end='', flush=True) + print('Added %d images.' % self.cur_images) + + def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order. + order = np.arange(self.expected_images) + np.random.RandomState(123).shuffle(order) + return order + + def add_image(self, img): + if self.print_progress and self.cur_images % self.progress_interval == 0: + print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True) + if self.shape is None: + self.shape = img.shape + self.resolution_log2 = int(np.log2(self.shape[1])) + assert self.shape[0] in [1, 3] + assert self.shape[1] == self.shape[2] + assert self.shape[1] == 2**self.resolution_log2 + tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) + for lod in range(self.resolution_log2 - 1): + tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod) + self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt)) + assert img.shape == self.shape + for lod, tfr_writer in enumerate(self.tfr_writers): + if lod: + img = img.astype(np.float32) + img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25 + quant = np.rint(img).clip(0, 255).astype(np.uint8) + ex = tf.train.Example(features=tf.train.Features(feature={ + 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)), + 'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))})) + tfr_writer.write(ex.SerializeToString()) + self.cur_images += 1 + + def add_labels(self, labels): + if self.print_progress: + print('%-40s\r' % 'Saving labels...', end='', flush=True) + assert labels.shape[0] == self.cur_images + with open(self.tfr_prefix + '-rxx.labels', 'wb') as f: + np.save(f, labels.astype(np.float32)) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + +#---------------------------------------------------------------------------- + +class ExceptionInfo(object): + def __init__(self): + self.value = 
sys.exc_info()[1] + self.traceback = traceback.format_exc() + +#---------------------------------------------------------------------------- + +class WorkerThread(threading.Thread): + def __init__(self, task_queue): + threading.Thread.__init__(self) + self.task_queue = task_queue + + def run(self): + while True: + func, args, result_queue = self.task_queue.get() + if func is None: + break + try: + result = func(*args) + except: + result = ExceptionInfo() + result_queue.put((result, args)) + +#---------------------------------------------------------------------------- + +class ThreadPool(object): + def __init__(self, num_threads): + assert num_threads >= 1 + self.task_queue = Queue.Queue() + self.result_queues = dict() + self.num_threads = num_threads + for _idx in range(self.num_threads): + thread = WorkerThread(self.task_queue) + thread.daemon = True + thread.start() + + def add_task(self, func, args=()): + assert hasattr(func, '__call__') # must be a function + if func not in self.result_queues: + self.result_queues[func] = Queue.Queue() + self.task_queue.put((func, args, self.result_queues[func])) + + def get_result(self, func): # returns (result, args) + result, args = self.result_queues[func].get() + if isinstance(result, ExceptionInfo): + print('\n\nWorker thread caught an exception:\n' + result.traceback) + raise result.value + return result, args + + def finish(self): + for _idx in range(self.num_threads): + self.task_queue.put((None, (), None)) + + def __enter__(self): # for 'with' statement + return self + + def __exit__(self, *excinfo): + self.finish() + + def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None): + if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4 + assert max_items_in_flight >= 1 + results = [] + retire_idx = [0] + + def task_func(prepared, _idx): + return process_func(prepared) + + def retire_result(): + processed, (_prepared, idx) = self.get_result(task_func) + results[idx] = processed + while retire_idx[0] < len(results) and results[retire_idx[0]] is not None: + yield post_func(results[retire_idx[0]]) + results[retire_idx[0]] = None + retire_idx[0] += 1 + + for idx, item in enumerate(item_iterator): + prepared = pre_func(item) + results.append(None) + self.add_task(func=task_func, args=(prepared, idx)) + while retire_idx[0] < idx - max_items_in_flight + 2: + for res in retire_result(): yield res + while retire_idx[0] < len(results): + for res in retire_result(): yield res + +#---------------------------------------------------------------------------- + +def display(tfrecord_dir): + print('Loading dataset "%s"' % tfrecord_dir) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + import cv2 # pip install opencv-python + + idx = 0 + while True: + try: + images, labels = dset.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + break + if idx == 0: + print('Displaying images') + cv2.namedWindow('dataset_tool') + print('Press SPACE or ENTER to advance, ESC to exit') + print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist())) + cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR + idx += 1 + if cv2.waitKey() == 27: + break + print('\nDisplayed %d images.' 
% idx) + +#---------------------------------------------------------------------------- + +def extract(tfrecord_dir, output_dir): + print('Loading dataset "%s"' % tfrecord_dir) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + + print('Extracting images to "%s"' % output_dir) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + idx = 0 + while True: + if idx % 10 == 0: + print('%d\r' % idx, end='', flush=True) + try: + images, _labels = dset.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + break + if images.shape[1] == 1: + img = PIL.Image.fromarray(images[0][0], 'L') + else: + img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB') + img.save(os.path.join(output_dir, 'img%08d.png' % idx)) + idx += 1 + print('Extracted %d images.' % idx) + +#---------------------------------------------------------------------------- + +def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels): + max_label_size = 0 if ignore_labels else 'full' + print('Loading dataset "%s"' % tfrecord_dir_a) + tflib.init_tf({'gpu_options.allow_growth': True}) + dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0) + print('Loading dataset "%s"' % tfrecord_dir_b) + dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0) + tflib.init_uninitialized_vars() + + print('Comparing datasets') + idx = 0 + identical_images = 0 + identical_labels = 0 + while True: + if idx % 100 == 0: + print('%d\r' % idx, end='', flush=True) + try: + images_a, labels_a = dset_a.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + images_a, labels_a = None, None + try: + images_b, labels_b = dset_b.get_minibatch_np(1) + except tf.errors.OutOfRangeError: + images_b, labels_b = None, None + if images_a is None or images_b is None: + if images_a is not None or images_b is not None: + print('Datasets contain different number of images') + break + if images_a.shape == images_b.shape and np.all(images_a == images_b): + identical_images += 1 + else: + print('Image %d is different' % idx) + if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b): + identical_labels += 1 + else: + print('Label %d is different' % idx) + idx += 1 + print('Identical images: %d / %d' % (identical_images, idx)) + if not ignore_labels: + print('Identical labels: %d / %d' % (identical_labels, idx)) + +#---------------------------------------------------------------------------- + +def create_mnist(tfrecord_dir, mnist_dir): + print('Loading MNIST from "%s"' % mnist_dir) + import gzip + with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file: + images = np.frombuffer(file.read(), np.uint8, offset=16) + with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file: + labels = np.frombuffer(file.read(), np.uint8, offset=8) + images = images.reshape(-1, 1, 28, 28) + images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0) + assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (60000,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) 
as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123): + print('Loading MNIST from "%s"' % mnist_dir) + import gzip + with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file: + images = np.frombuffer(file.read(), np.uint8, offset=16) + images = images.reshape(-1, 28, 28) + images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) + assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + + with TFRecordExporter(tfrecord_dir, num_images) as tfr: + rnd = np.random.RandomState(random_seed) + for _idx in range(num_images): + tfr.add_image(images[rnd.randint(images.shape[0], size=3)]) + +#---------------------------------------------------------------------------- + +def create_cifar10(tfrecord_dir, cifar10_dir): + print('Loading CIFAR-10 from "%s"' % cifar10_dir) + import pickle + images = [] + labels = [] + for batch in range(1, 6): + with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file: + data = pickle.load(file, encoding='latin1') + images.append(data['data'].reshape(-1, 3, 32, 32)) + labels.append(data['labels']) + images = np.concatenate(images) + labels = np.concatenate(labels) + assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype == np.int32 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_cifar100(tfrecord_dir, cifar100_dir): + print('Loading CIFAR-100 from "%s"' % cifar100_dir) + import pickle + with open(os.path.join(cifar100_dir, 'train'), 'rb') as file: + data = pickle.load(file, encoding='latin1') + images = data['data'].reshape(-1, 3, 32, 32) + labels = np.array(data['fine_labels']) + assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8 + assert labels.shape == (50000,) and labels.dtype == np.int32 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 99 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_svhn(tfrecord_dir, svhn_dir): + print('Loading SVHN from "%s"' % svhn_dir) + import pickle + images = [] + labels = [] + for batch in range(1, 4): + with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file: + data = pickle.load(file, encoding='latin1') + images.append(data[0]) + labels.append(data[1]) + images = np.concatenate(images) + labels = np.concatenate(labels) + assert images.shape == (73257, 3, 32, 32) 
and images.dtype == np.uint8 + assert labels.shape == (73257,) and labels.dtype == np.uint8 + assert np.min(images) == 0 and np.max(images) == 255 + assert np.min(labels) == 0 and np.max(labels) == 9 + onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32) + onehot[np.arange(labels.size), labels] = 1.0 + + with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + tfr.add_image(images[order[idx]]) + tfr.add_labels(onehot[order]) + +#---------------------------------------------------------------------------- + +def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None): + print('Loading LSUN dataset from "%s"' % lmdb_dir) + import lmdb # pip install lmdb # pylint: disable=import-error + import cv2 # pip install opencv-python + import io + with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn: + total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter + if max_images is None: + max_images = total_images + with TFRecordExporter(tfrecord_dir, max_images) as tfr: + for _idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1) + if img is None: + raise IOError('cv2.imdecode failed') + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.asarray(PIL.Image.open(io.BytesIO(value))) + crop = np.min(img.shape[:2]) + img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS) + img = np.asarray(img) + img = img.transpose([2, 0, 1]) # HWC => CHW + tfr.add_image(img) + except: + print(sys.exc_info()[1]) + if tfr.cur_images == max_images: + break + +#---------------------------------------------------------------------------- + +def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None): + assert width == 2 ** int(np.round(np.log2(width))) + assert height <= width + print('Loading LSUN dataset from "%s"' % lmdb_dir) + import lmdb # pip install lmdb # pylint: disable=import-error + import cv2 # pip install opencv-python + import io + with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn: + total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter + if max_images is None: + max_images = total_images + with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr: + for idx, (_key, value) in enumerate(txn.cursor()): + try: + try: + img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1) + if img is None: + raise IOError('cv2.imdecode failed') + img = img[:, :, ::-1] # BGR => RGB + except IOError: + img = np.asarray(PIL.Image.open(io.BytesIO(value))) + + ch = int(np.round(width * img.shape[0] / img.shape[1])) + if img.shape[1] < width or ch < height: + continue + + img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] + img = PIL.Image.fromarray(img, 'RGB') + img = img.resize((width, height), PIL.Image.ANTIALIAS) + img = np.asarray(img) + img = img.transpose([2, 0, 1]) # HWC => CHW + + canvas = np.zeros([3, width, width], dtype=np.uint8) + canvas[:, (width - height) // 2 : (width + height) // 2] = img + tfr.add_image(canvas) + print('\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='') + + except: + print(sys.exc_info()[1]) + if tfr.cur_images == max_images: + break + print() + 
+#---------------------------------------------------------------------------- + +def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121): + print('Loading CelebA from "%s"' % celeba_dir) + glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png') + image_filenames = sorted(glob.glob(glob_pattern)) + expected_images = 202599 + if len(image_filenames) != expected_images: + error('Expected to find %d images' % expected_images) + + with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr: + order = tfr.choose_shuffled_order() + for idx in range(order.size): + img = np.asarray(PIL.Image.open(image_filenames[order[idx]])) + assert img.shape == (218, 178, 3) + img = img[cy - 64 : cy + 64, cx - 64 : cx + 64] + img = img.transpose(2, 0, 1) # HWC => CHW + tfr.add_image(img) + +#---------------------------------------------------------------------------- + +def create_from_images(tfrecord_dir, image_dir, shuffle): + print('Loading images from "%s"' % image_dir) + image_filenames = sorted(glob.glob(os.path.join(image_dir, '*'))) + if len(image_filenames) == 0: + error('No input images found') + + img = np.asarray(PIL.Image.open(image_filenames[0])) + resolution = img.shape[0] + channels = img.shape[2] if img.ndim == 3 else 1 + if img.shape[1] != resolution: + error('Input images must have the same width and height') + if resolution != 2 ** int(np.floor(np.log2(resolution))): + error('Input image resolution must be a power-of-two') + if channels not in [1, 3]: + error('Input images must be stored as RGB or grayscale') + + with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr: + order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames)) + for idx in range(order.size): + img = np.asarray(PIL.Image.open(image_filenames[order[idx]])) + if channels == 1: + img = img[np.newaxis, :, :] # HW => CHW + else: + img = img.transpose([2, 0, 1]) # HWC => CHW + tfr.add_image(img) + +#---------------------------------------------------------------------------- + +def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle): + print('Loading HDF5 archive from "%s"' % hdf5_filename) + import h5py # conda install h5py + with h5py.File(hdf5_filename, 'r') as hdf5_file: + hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3]) + with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr: + order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0]) + for idx in range(order.size): + tfr.add_image(hdf5_data[order[idx]]) + npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy' + if os.path.isfile(npy_filename): + tfr.add_labels(np.load(npy_filename)[order]) + +#---------------------------------------------------------------------------- + +def execute_cmdline(argv): + prog = argv[0] + parser = argparse.ArgumentParser( + prog = prog, + description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.', + epilog = 'Type "%s -h" for more information.' 
% prog) + + subparsers = parser.add_subparsers(dest='command') + subparsers.required = True + def add_command(cmd, desc, example=None): + epilog = 'Example: %s %s' % (prog, example) if example is not None else None + return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog) + + p = add_command( 'display', 'Display images in dataset.', + 'display datasets/mnist') + p.add_argument( 'tfrecord_dir', help='Directory containing dataset') + + p = add_command( 'extract', 'Extract images from dataset.', + 'extract datasets/mnist mnist-images') + p.add_argument( 'tfrecord_dir', help='Directory containing dataset') + p.add_argument( 'output_dir', help='Directory to extract the images into') + + p = add_command( 'compare', 'Compare two datasets.', + 'compare datasets/mydataset datasets/mnist') + p.add_argument( 'tfrecord_dir_a', help='Directory containing first dataset') + p.add_argument( 'tfrecord_dir_b', help='Directory containing second dataset') + p.add_argument( '--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0) + + p = add_command( 'create_mnist', 'Create dataset for MNIST.', + 'create_mnist datasets/mnist ~/downloads/mnist') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'mnist_dir', help='Directory containing MNIST') + + p = add_command( 'create_mnistrgb', 'Create dataset for MNIST-RGB.', + 'create_mnistrgb datasets/mnistrgb ~/downloads/mnist') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'mnist_dir', help='Directory containing MNIST') + p.add_argument( '--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000) + p.add_argument( '--random_seed', help='Random seed (default: 123)', type=int, default=123) + + p = add_command( 'create_cifar10', 'Create dataset for CIFAR-10.', + 'create_cifar10 datasets/cifar10 ~/downloads/cifar10') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'cifar10_dir', help='Directory containing CIFAR-10') + + p = add_command( 'create_cifar100', 'Create dataset for CIFAR-100.', + 'create_cifar100 datasets/cifar100 ~/downloads/cifar100') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'cifar100_dir', help='Directory containing CIFAR-100') + + p = add_command( 'create_svhn', 'Create dataset for SVHN.', + 'create_svhn datasets/svhn ~/downloads/svhn') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'svhn_dir', help='Directory containing SVHN') + + p = add_command( 'create_lsun', 'Create dataset for single LSUN category.', + 'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'lmdb_dir', help='Directory containing LMDB database') + p.add_argument( '--resolution', help='Output resolution (default: 256)', type=int, default=256) + p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None) + + p = add_command( 'create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.', + 'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'lmdb_dir', help='Directory containing LMDB database') + p.add_argument( '--width', 
help='Output width (default: 512)', type=int, default=512) + p.add_argument( '--height', help='Output height (default: 384)', type=int, default=384) + p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None) + + p = add_command( 'create_celeba', 'Create dataset for CelebA.', + 'create_celeba datasets/celeba ~/downloads/celeba') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'celeba_dir', help='Directory containing CelebA') + p.add_argument( '--cx', help='Center X coordinate (default: 89)', type=int, default=89) + p.add_argument( '--cy', help='Center Y coordinate (default: 121)', type=int, default=121) + + p = add_command( 'create_from_images', 'Create dataset from a directory full of images.', + 'create_from_images datasets/mydataset myimagedir') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'image_dir', help='Directory containing the images') + p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1) + + p = add_command( 'create_from_hdf5', 'Create dataset from legacy HDF5 archive.', + 'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5') + p.add_argument( 'tfrecord_dir', help='New dataset directory to be created') + p.add_argument( 'hdf5_filename', help='HDF5 archive containing the images') + p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1) + + args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h']) + func = globals()[args.command] + del args.command + func(**vars(args)) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + execute_cmdline(sys.argv) + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/dnnlib/__init__.py b/models/stylegan/stylegan_tf/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad43827d8a279c4a797e09b51b8fd96e8e003ee6 --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +from . import submission + +from .submission.run_context import RunContext + +from .submission.submit import SubmitTarget +from .submission.submit import PathType +from .submission.submit import SubmitConfig +from .submission.submit import get_path_from_template +from .submission.submit import submit_run + +from .util import EasyDict + +submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. diff --git a/models/stylegan/stylegan_tf/dnnlib/submission/__init__.py b/models/stylegan/stylegan_tf/dnnlib/submission/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53856121d673459ae2b21ecef3d0fcb12a12cdfe --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/submission/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. 
To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +from . import run_context +from . import submit diff --git a/models/stylegan/stylegan_tf/dnnlib/submission/_internal/run.py b/models/stylegan/stylegan_tf/dnnlib/submission/_internal/run.py new file mode 100644 index 0000000000000000000000000000000000000000..18f830d81ead15fece09382cc30654fb89d14d1b --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/submission/_internal/run.py @@ -0,0 +1,45 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper for launching run functions in computing clusters. + +During the submit process, this file is copied to the appropriate run dir. +When the job is launched in the cluster, this module is the first thing that +is run inside the docker container. +""" + +import os +import pickle +import sys + +# PYTHONPATH should have been set so that the run_dir/src is in it +import dnnlib + +def main(): + if not len(sys.argv) >= 4: + raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") + + run_dir = str(sys.argv[1]) + task_name = str(sys.argv[2]) + host_name = str(sys.argv[3]) + + submit_config_path = os.path.join(run_dir, "submit_config.pkl") + + # SubmitConfig should have been pickled to the run dir + if not os.path.exists(submit_config_path): + raise RuntimeError("SubmitConfig pickle file does not exist!") + + submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) + dnnlib.submission.submit.set_user_name_override(submit_config.user_name) + + submit_config.task_name = task_name + submit_config.host_name = host_name + + dnnlib.submission.submit.run_wrapper(submit_config) + +if __name__ == "__main__": + main() diff --git a/models/stylegan/stylegan_tf/dnnlib/submission/run_context.py b/models/stylegan/stylegan_tf/dnnlib/submission/run_context.py new file mode 100644 index 0000000000000000000000000000000000000000..932320e4735bde1b547ac6062b175601b7959547 --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/submission/run_context.py @@ -0,0 +1,99 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helpers for managing the run/training loop.""" + +import datetime +import json +import os +import pprint +import time +import types + +from typing import Any + +from . import submit + + +class RunContext(object): + """Helper class for managing the run/training loop. + + The context will hide the implementation details of a basic run/training loop. + It will set things up properly, tell if run should be stopped, and then cleans up. + User should call update periodically and use should_stop to determine if run should be stopped. + + Args: + submit_config: The SubmitConfig that is used for the current run. + config_module: The whole config module that is used for the current run. 
+ max_epoch: Optional cached value for the max_epoch variable used in update. + """ + + def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): + self.submit_config = submit_config + self.should_stop_flag = False + self.has_closed = False + self.start_time = time.time() + self.last_update_time = time.time() + self.last_update_interval = 0.0 + self.max_epoch = max_epoch + + # pretty print the all the relevant content of the config module to a text file + if config_module is not None: + with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: + filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} + pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) + + # write out details about the run to a text file + self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} + with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: + pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) + + def __enter__(self) -> "RunContext": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: + """Do general housekeeping and keep the state of the context up-to-date. + Should be called often enough but not in a tight loop.""" + assert not self.has_closed + + self.last_update_interval = time.time() - self.last_update_time + self.last_update_time = time.time() + + if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): + self.should_stop_flag = True + + max_epoch_val = self.max_epoch if max_epoch is None else max_epoch + + def should_stop(self) -> bool: + """Tell whether a stopping condition has been triggered one way or another.""" + return self.should_stop_flag + + def get_time_since_start(self) -> float: + """How much time has passed since the creation of the context.""" + return time.time() - self.start_time + + def get_time_since_last_update(self) -> float: + """How much time has passed since the last call to update.""" + return time.time() - self.last_update_time + + def get_last_update_interval(self) -> float: + """How much time passed between the previous two calls to update.""" + return self.last_update_interval + + def close(self) -> None: + """Close the context and clean up. + Should only be called once.""" + if not self.has_closed: + # update the run.txt with stopping time + self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") + with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: + pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) + + self.has_closed = True diff --git a/models/stylegan/stylegan_tf/dnnlib/submission/submit.py b/models/stylegan/stylegan_tf/dnnlib/submission/submit.py new file mode 100644 index 0000000000000000000000000000000000000000..60ff428717c13896bb78625b3eaf651d9fb9695d --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/submission/submit.py @@ -0,0 +1,290 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. 
To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Submit a function to be run either locally or in a computing cluster.""" + +import copy +import io +import os +import pathlib +import pickle +import platform +import pprint +import re +import shutil +import time +import traceback + +import zipfile + +from enum import Enum + +from .. import util +from ..util import EasyDict + + +class SubmitTarget(Enum): + """The target where the function should be run. + + LOCAL: Run it locally. + """ + LOCAL = 1 + + +class PathType(Enum): + """Determines in which format should a path be formatted. + + WINDOWS: Format with Windows style. + LINUX: Format with Linux/Posix style. + AUTO: Use current OS type to select either WINDOWS or LINUX. + """ + WINDOWS = 1 + LINUX = 2 + AUTO = 3 + + +_user_name_override = None + + +class SubmitConfig(util.EasyDict): + """Strongly typed config dict needed to submit runs. + + Attributes: + run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template. + run_desc: Description of the run. Will be used in the run dir and task name. + run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir. + run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir. + submit_target: Submit target enum value. Used to select where the run is actually launched. + num_gpus: Number of GPUs used/requested for the run. + print_info: Whether to print debug information when submitting. + ask_confirmation: Whether to ask a confirmation before submitting. + run_id: Automatically populated value during submit. + run_name: Automatically populated value during submit. + run_dir: Automatically populated value during submit. + run_func_name: Automatically populated value during submit. + run_func_kwargs: Automatically populated value during submit. + user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value. + task_name: Automatically populated value during submit. + host_name: Automatically populated value during submit. 
+ """ + + def __init__(self): + super().__init__() + + # run (set these) + self.run_dir_root = "" # should always be passed through get_path_from_template + self.run_desc = "" + self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"] + self.run_dir_extra_files = None + + # submit (set these) + self.submit_target = SubmitTarget.LOCAL + self.num_gpus = 1 + self.print_info = False + self.ask_confirmation = False + + # (automatically populated) + self.run_id = None + self.run_name = None + self.run_dir = None + self.run_func_name = None + self.run_func_kwargs = None + self.user_name = None + self.task_name = None + self.host_name = "localhost" + + +def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str: + """Replace tags in the given path template and return either Windows or Linux formatted path.""" + # automatically select path type depending on running OS + if path_type == PathType.AUTO: + if platform.system() == "Windows": + path_type = PathType.WINDOWS + elif platform.system() == "Linux": + path_type = PathType.LINUX + else: + raise RuntimeError("Unknown platform") + + path_template = path_template.replace("", get_user_name()) + + # return correctly formatted path + if path_type == PathType.WINDOWS: + return str(pathlib.PureWindowsPath(path_template)) + elif path_type == PathType.LINUX: + return str(pathlib.PurePosixPath(path_template)) + else: + raise RuntimeError("Unknown platform") + + +def get_template_from_path(path: str) -> str: + """Convert a normal path back to its template representation.""" + # replace all path parts with the template tags + path = path.replace("\\", "/") + return path + + +def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str: + """Convert a normal path to template and the convert it back to a normal path with given path type.""" + path_template = get_template_from_path(path) + path = get_path_from_template(path_template, path_type) + return path + + +def set_user_name_override(name: str) -> None: + """Set the global username override value.""" + global _user_name_override + _user_name_override = name + + +def get_user_name(): + """Get the current user name.""" + if _user_name_override is not None: + return _user_name_override + elif platform.system() == "Windows": + return os.getlogin() + elif platform.system() == "Linux": + try: + import pwd # pylint: disable=import-error + return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member + except: + return "unknown" + else: + raise RuntimeError("Unknown platform") + + +def _create_run_dir_local(submit_config: SubmitConfig) -> str: + """Create a new run dir with increasing ID number at the start.""" + run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO) + + if not os.path.exists(run_dir_root): + print("Creating the run dir root: {}".format(run_dir_root)) + os.makedirs(run_dir_root) + + submit_config.run_id = _get_next_run_id_local(run_dir_root) + submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc) + run_dir = os.path.join(run_dir_root, submit_config.run_name) + + if os.path.exists(run_dir): + raise RuntimeError("The run dir already exists! ({0})".format(run_dir)) + + print("Creating the run dir: {}".format(run_dir)) + os.makedirs(run_dir) + + return run_dir + + +def _get_next_run_id_local(run_dir_root: str) -> int: + """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. 
Assumes IDs are numbers at the start of the directory names.""" + dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))] + r = re.compile("^\\d+") # match one or more digits at the start of the string + run_id = 0 + + for dir_name in dir_names: + m = r.match(dir_name) + + if m is not None: + i = int(m.group()) + run_id = max(run_id, i + 1) + + return run_id + + +def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None: + """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.""" + print("Copying files to the run dir") + files = [] + + run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name) + assert '.' in submit_config.run_func_name + for _idx in range(submit_config.run_func_name.count('.') - 1): + run_func_module_dir_path = os.path.dirname(run_func_module_dir_path) + files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False) + + dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib") + files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True) + + if submit_config.run_dir_extra_files is not None: + files += submit_config.run_dir_extra_files + + files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files] + files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))] + + util.copy_files_and_create_dirs(files) + + pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb")) + + with open(os.path.join(run_dir, "submit_config.txt"), "w") as f: + pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False) + + +def run_wrapper(submit_config: SubmitConfig) -> None: + """Wrap the actual run function call for handling logging, exceptions, typing, etc.""" + is_local = submit_config.submit_target == SubmitTarget.LOCAL + + checker = None + + # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing + if is_local: + logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True) + else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh) + logger = util.Logger(file_name=None, should_flush=True) + + import dnnlib + dnnlib.submit_config = submit_config + + try: + print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name)) + start_time = time.time() + util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs) + print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time))) + except: + if is_local: + raise + else: + traceback.print_exc() + + log_src = os.path.join(submit_config.run_dir, "log.txt") + log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name)) + shutil.copyfile(log_src, log_dst) + finally: + open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close() + + dnnlib.submit_config = None + logger.close() + + if checker is not None: + checker.stop() + + +def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None: + """Create a run dir, gather files related to the run, copy files to the 
run dir, and launch the run in appropriate place.""" + submit_config = copy.copy(submit_config) + + if submit_config.user_name is None: + submit_config.user_name = get_user_name() + + submit_config.run_func_name = run_func_name + submit_config.run_func_kwargs = run_func_kwargs + + assert submit_config.submit_target == SubmitTarget.LOCAL + if submit_config.submit_target in {SubmitTarget.LOCAL}: + run_dir = _create_run_dir_local(submit_config) + + submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc) + submit_config.run_dir = run_dir + _populate_run_dir(run_dir, submit_config) + + if submit_config.print_info: + print("\nSubmit config:\n") + pprint.pprint(submit_config, indent=4, width=200, compact=False) + print() + + if submit_config.ask_confirmation: + if not util.ask_yes_no("Continue submitting the job?"): + return + + run_wrapper(submit_config) diff --git a/models/stylegan/stylegan_tf/dnnlib/tflib/__init__.py b/models/stylegan/stylegan_tf/dnnlib/tflib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f054a39cb81e38ca8b1f4ad5bac168aa68e7d92e --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/tflib/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +from . import autosummary +from . import network +from . import optimizer +from . import tfutil + +from .tfutil import * +from .network import Network + +from .optimizer import Optimizer diff --git a/models/stylegan/stylegan_tf/dnnlib/tflib/autosummary.py b/models/stylegan/stylegan_tf/dnnlib/tflib/autosummary.py new file mode 100644 index 0000000000000000000000000000000000000000..43154f792e5ebe15ee6045a5acdfb279cebefcaa --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/tflib/autosummary.py @@ -0,0 +1,184 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper for adding automatically tracked values to Tensorboard. + +Autosummary creates an identity op that internally keeps track of the input +values and automatically shows up in TensorBoard. The reported value +represents an average over input components. The average is accumulated +constantly over time and flushed when save_summaries() is called. + +Notes: +- The output tensor must be used as an input for something else in the + graph. Otherwise, the autosummary op will not get executed, and the average + value will not get accumulated. +- It is perfectly fine to include autosummaries with the same name in + several places throughout the graph, even if they are executed concurrently. +- It is ok to also pass in a python scalar or numpy array. In this case, it + is added to the average immediately. +""" + +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorboard import summary as summary_lib +from tensorboard.plugins.custom_scalar import layout_pb2 + +from . 
import tfutil +from .tfutil import TfExpression +from .tfutil import TfExpressionEx + +_dtype = tf.float64 +_vars = OrderedDict() # name => [var, ...] +_immediate = OrderedDict() # name => update_op, update_value +_finalized = False +_merge_op = None + + +def _create_var(name: str, value_expr: TfExpression) -> TfExpression: + """Internal helper for creating autosummary accumulators.""" + assert not _finalized + name_id = name.replace("/", "_") + v = tf.cast(value_expr, _dtype) + + if v.shape.is_fully_defined(): + size = np.prod(tfutil.shape_to_list(v.shape)) + size_expr = tf.constant(size, dtype=_dtype) + else: + size = None + size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) + + if size == 1: + if v.shape.ndims != 0: + v = tf.reshape(v, []) + v = [size_expr, v, tf.square(v)] + else: + v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] + v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) + + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): + var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] + update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) + + if name in _vars: + _vars[name].append(var) + else: + _vars[name] = [var] + return update_op + + +def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: + """Create a new autosummary. + + Args: + name: Name to use in TensorBoard + value: TensorFlow expression or python value to track + passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. + + Example use of the passthru mechanism: + + n = autosummary('l2loss', loss, passthru=n) + + This is a shorthand for the following code: + + with tf.control_dependencies([autosummary('l2loss', loss)]): + n = tf.identity(n) + """ + tfutil.assert_tf_initialized() + name_id = name.replace("/", "_") + + if tfutil.is_tf_expression(value): + with tf.name_scope("summary_" + name_id), tf.device(value.device): + update_op = _create_var(name, value) + with tf.control_dependencies([update_op]): + return tf.identity(value if passthru is None else passthru) + + else: # python scalar or numpy array + if name not in _immediate: + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): + update_value = tf.placeholder(_dtype) + update_op = _create_var(name, update_value) + _immediate[name] = update_op, update_value + + update_op, update_value = _immediate[name] + tfutil.run(update_op, {update_value: value}) + return value if passthru is None else passthru + + +def finalize_autosummaries() -> None: + """Create the necessary ops to include autosummaries in TensorBoard report. + Note: This should be done only once per graph. + """ + global _finalized + tfutil.assert_tf_initialized() + + if _finalized: + return None + + _finalized = True + tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) + + # Create summary ops. 
+ with tf.device(None), tf.control_dependencies(None): + for name, vars_list in _vars.items(): + name_id = name.replace("/", "_") + with tfutil.absolute_name_scope("Autosummary/" + name_id): + moments = tf.add_n(vars_list) + moments /= moments[0] + with tf.control_dependencies([moments]): # read before resetting + reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] + with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting + mean = moments[1] + std = tf.sqrt(moments[2] - tf.square(moments[1])) + tf.summary.scalar(name, mean) + tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) + tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) + + # Group by category and chart name. + cat_dict = OrderedDict() + for series_name in sorted(_vars.keys()): + p = series_name.split("/") + cat = p[0] if len(p) >= 2 else "" + chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] + if cat not in cat_dict: + cat_dict[cat] = OrderedDict() + if chart not in cat_dict[cat]: + cat_dict[cat][chart] = [] + cat_dict[cat][chart].append(series_name) + + # Setup custom_scalar layout. + categories = [] + for cat_name, chart_dict in cat_dict.items(): + charts = [] + for chart_name, series_names in chart_dict.items(): + series = [] + for series_name in series_names: + series.append(layout_pb2.MarginChartContent.Series( + value=series_name, + lower="xCustomScalars/" + series_name + "/margin_lo", + upper="xCustomScalars/" + series_name + "/margin_hi")) + margin = layout_pb2.MarginChartContent(series=series) + charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) + categories.append(layout_pb2.Category(title=cat_name, chart=charts)) + layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) + return layout + +def save_summaries(file_writer, global_step=None): + """Call FileWriter.add_summary() with all summaries in the default graph, + automatically finalizing and merging them on the first call. + """ + global _merge_op + tfutil.assert_tf_initialized() + + if _merge_op is None: + layout = finalize_autosummaries() + if layout is not None: + file_writer.add_summary(layout) + with tf.device(None), tf.control_dependencies(None): + _merge_op = tf.summary.merge_all() + + file_writer.add_summary(_merge_op.eval(), global_step) diff --git a/models/stylegan/stylegan_tf/dnnlib/tflib/network.py b/models/stylegan/stylegan_tf/dnnlib/tflib/network.py new file mode 100644 index 0000000000000000000000000000000000000000..d888a90dd23c1a941b5fb501afec1efcb763b5ea --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/tflib/network.py @@ -0,0 +1,591 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper for managing networks.""" + +import types +import inspect +import re +import uuid +import sys +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import Any, List, Tuple, Union + +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. +_import_module_src = dict() # Source code for temporary modules created during pickle import. 
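For context, the autosummary module added above is normally driven from a training loop: values are accumulated with autosummary() and periodically flushed to TensorBoard with save_summaries(). A minimal illustrative sketch follows; the log directory, step count, and the random stand-in for a real loss value are placeholders rather than part of this patch.

import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib import autosummary

tflib.init_tf()                                        # create the default TF session
writer = tf.summary.FileWriter("logs")                 # placeholder log directory
for step in range(1000):
    loss_value = float(np.random.rand())               # stand-in for a real training loss
    autosummary.autosummary("Loss/train", loss_value)  # accumulate a python scalar
    if step % 100 == 0:
        autosummary.save_summaries(writer, global_step=step)
writer.close()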
+ + +def import_handler(handler_func): + """Function decorator for declaring custom import handlers.""" + _import_handlers.append(handler_func) + return handler_func + + +class Network: + """Generic network abstraction. + + Acts as a convenience wrapper for a parameterized network construction + function, providing several utility methods and convenient access to + the inputs/outputs/weights. + + Network objects can be safely pickled and unpickled for long-term + archival purposes. The pickling works reliably as long as the underlying + network construction function is defined in a standalone Python module + that has no side effects or application-specific imports. + + Args: + name: Network name. Used to select TensorFlow name and variable scopes. + func_name: Fully qualified name of the underlying network construction function, or a top-level function object. + static_kwargs: Keyword arguments to be passed in to the network construction function. + + Attributes: + name: User-specified name, defaults to build func name if None. + scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name. + static_kwargs: Arguments passed to the user-supplied build func. + components: Container for sub-networks. Passed to the build func, and retained between calls. + num_inputs: Number of input tensors. + num_outputs: Number of output tensors. + input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension. + output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension. + input_shape: Short-hand for input_shapes[0]. + output_shape: Short-hand for output_shapes[0]. + input_templates: Input placeholders in the template graph. + output_templates: Output tensors in the template graph. + input_names: Name string for each input. + output_names: Name string for each output. + own_vars: Variables defined by this network (local_name => var), excluding sub-networks. + vars: All variables (local_name => var). + trainables: All trainable variables (local_name => var). + var_global_to_local: Mapping from variable global names to local names. + """ + + def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): + tfutil.assert_tf_initialized() + assert isinstance(name, str) or name is None + assert func_name is not None + assert isinstance(func_name, str) or util.is_top_level_function(func_name) + assert util.is_pickleable(static_kwargs) + + self._init_fields() + self.name = name + self.static_kwargs = util.EasyDict(static_kwargs) + + # Locate the user-specified network build function. + if util.is_top_level_function(func_name): + func_name = util.get_top_level_function_name(func_name) + module, self._build_func_name = util.get_module_from_obj_name(func_name) + self._build_func = util.get_obj_from_module(module, self._build_func_name) + assert callable(self._build_func) + + # Dig up source code for the module containing the build function. + self._build_module_src = _import_module_src.get(module, None) + if self._build_module_src is None: + self._build_module_src = inspect.getsource(module) + + # Init TensorFlow graph. 
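+        # Build the template graph under a fresh, unique scope and initialize only this
+        # network's own variables; sub-networks created by the build func initialize theirs.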
+ self._init_graph() + self.reset_own_vars() + + def _init_fields(self) -> None: + self.name = None + self.scope = None + self.static_kwargs = util.EasyDict() + self.components = util.EasyDict() + self.num_inputs = 0 + self.num_outputs = 0 + self.input_shapes = [[]] + self.output_shapes = [[]] + self.input_shape = [] + self.output_shape = [] + self.input_templates = [] + self.output_templates = [] + self.input_names = [] + self.output_names = [] + self.own_vars = OrderedDict() + self.vars = OrderedDict() + self.trainables = OrderedDict() + self.var_global_to_local = OrderedDict() + + self._build_func = None # User-supplied build function that constructs the network. + self._build_func_name = None # Name of the build function. + self._build_module_src = None # Full source code of the module containing the build function. + self._run_cache = dict() # Cached graph data for Network.run(). + + def _init_graph(self) -> None: + # Collect inputs. + self.input_names = [] + + for param in inspect.signature(self._build_func).parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: + self.input_names.append(param.name) + + self.num_inputs = len(self.input_names) + assert self.num_inputs >= 1 + + # Choose name and scope. + if self.name is None: + self.name = self._build_func_name + assert re.match("^[A-Za-z0-9_.\\-]*$", self.name) + with tf.name_scope(None): + self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True) + + # Finalize build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs["is_template_graph"] = True + build_kwargs["components"] = self.components + + # Build template graph. + with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes + assert tf.get_variable_scope().name == self.scope + assert tf.get_default_graph().get_name_scope() == self.scope + with tf.control_dependencies(None): # ignore surrounding control dependencies + self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names] + out_expr = self._build_func(*self.input_templates, **build_kwargs) + + # Collect outputs. + assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + self.num_outputs = len(self.output_templates) + assert self.num_outputs >= 1 + assert all(tfutil.is_tf_expression(t) for t in self.output_templates) + + # Perform sanity checks. + if any(t.shape.ndims is None for t in self.input_templates): + raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") + if any(t.shape.ndims is None for t in self.output_templates): + raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.") + if any(not isinstance(comp, Network) for comp in self.components.values()): + raise ValueError("Components of a Network must be Networks themselves.") + if len(self.components) != len(set(comp.name for comp in self.components.values())): + raise ValueError("Components of a Network must have unique names.") + + # List inputs and outputs. 
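+        # Variables are keyed by their local name, i.e. the global TF name with the
+        # "<scope>/" prefix and ":0" suffix stripped, so that values can be matched by
+        # name between networks (see copy_vars_from / copy_trainables_from below).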
+ self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates] + self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates] + self.input_shape = self.input_shapes[0] + self.output_shape = self.output_shapes[0] + self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] + + # List variables. + self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) + self.vars = OrderedDict(self.own_vars) + self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items()) + self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) + self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) + + def reset_own_vars(self) -> None: + """Re-initialize all variables of this network, excluding sub-networks.""" + tfutil.run([var.initializer for var in self.own_vars.values()]) + + def reset_vars(self) -> None: + """Re-initialize all variables of this network, including sub-networks.""" + tfutil.run([var.initializer for var in self.vars.values()]) + + def reset_trainables(self) -> None: + """Re-initialize all trainable variables of this network, including sub-networks.""" + tfutil.run([var.initializer for var in self.trainables.values()]) + + def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: + """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).""" + assert len(in_expr) == self.num_inputs + assert not all(expr is None for expr in in_expr) + + # Finalize build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs.update(dynamic_kwargs) + build_kwargs["is_template_graph"] = False + build_kwargs["components"] = self.components + + # Build TensorFlow graph to evaluate the network. + with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): + assert tf.get_variable_scope().name == self.scope + valid_inputs = [expr for expr in in_expr if expr is not None] + final_inputs = [] + for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): + if expr is not None: + expr = tf.identity(expr, name=name) + else: + expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) + final_inputs.append(expr) + out_expr = self._build_func(*final_inputs, **build_kwargs) + + # Propagate input shapes back to the user-specified expressions. + for expr, final in zip(in_expr, final_inputs): + if isinstance(expr, tf.Tensor): + expr.set_shape(final.shape) + + # Express outputs in the desired format. 
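+        # Return a single tensor or tuple exactly as produced by the build func, unless
+        # return_as_list is set, in which case the output(s) are always wrapped in a list.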
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + if return_as_list: + out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + return out_expr + + def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: + """Get the local name of a given variable, without any surrounding name scopes.""" + assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) + global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name + return self.var_global_to_local[global_name] + + def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: + """Find variable by local or global name.""" + assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) + return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name + + def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: + """Get the value of a given variable as NumPy array. + Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" + return self.find_var(var_or_local_name).eval() + + def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: + """Set the value of a given variable based on the given NumPy array. + Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" + tfutil.set_vars({self.find_var(var_or_local_name): new_value}) + + def __getstate__(self) -> dict: + """Pickle export.""" + state = dict() + state["version"] = 3 + state["name"] = self.name + state["static_kwargs"] = dict(self.static_kwargs) + state["components"] = dict(self.components) + state["build_module_src"] = self._build_module_src + state["build_func_name"] = self._build_func_name + state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values())))) + return state + + def __setstate__(self, state: dict) -> None: + """Pickle import.""" + # pylint: disable=attribute-defined-outside-init + tfutil.assert_tf_initialized() + self._init_fields() + + # Execute custom import handlers. + for handler in _import_handlers: + state = handler(state) + + # Set basic fields. + assert state["version"] in [2, 3] + self.name = state["name"] + self.static_kwargs = util.EasyDict(state["static_kwargs"]) + self.components = util.EasyDict(state.get("components", {})) + self._build_module_src = state["build_module_src"] + self._build_func_name = state["build_func_name"] + + # Create temporary module from the imported source code. + module_name = "_tflib_network_import_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _import_module_src[module] = self._build_module_src + exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used + + # Locate network build function in the temporary module. + self._build_func = util.get_obj_from_module(module, self._build_func_name) + assert callable(self._build_func) + + # Init TensorFlow graph. 
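+        # Rebuild the template graph from the imported source code, then restore the
+        # pickled variable values into the freshly created variables.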
+ self._init_graph() + self.reset_own_vars() + tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]}) + + def clone(self, name: str = None, **new_static_kwargs) -> "Network": + """Create a clone of this network with its own copy of the variables.""" + # pylint: disable=protected-access + net = object.__new__(Network) + net._init_fields() + net.name = name if name is not None else self.name + net.static_kwargs = util.EasyDict(self.static_kwargs) + net.static_kwargs.update(new_static_kwargs) + net._build_module_src = self._build_module_src + net._build_func_name = self._build_func_name + net._build_func = self._build_func + net._init_graph() + net.copy_vars_from(self) + return net + + def copy_own_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, excluding sub-networks.""" + names = [name for name in self.own_vars.keys() if name in src_net.own_vars] + tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) + + def copy_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, including sub-networks.""" + names = [name for name in self.vars.keys() if name in src_net.vars] + tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) + + def copy_trainables_from(self, src_net: "Network") -> None: + """Copy the values of all trainable variables from the given network, including sub-networks.""" + names = [name for name in self.trainables.keys() if name in src_net.trainables] + tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names})) + + def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network": + """Create new network with the given parameters, and copy all variables from this network.""" + if new_name is None: + new_name = self.name + static_kwargs = dict(self.static_kwargs) + static_kwargs.update(new_static_kwargs) + net = Network(name=new_name, func_name=new_func_name, **static_kwargs) + net.copy_vars_from(self) + return net + + def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation: + """Construct a TensorFlow op that updates the variables of this network + to be slightly closer to those of the given network.""" + with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"): + ops = [] + for name, var in self.vars.items(): + if name in src_net.vars: + cur_beta = beta if name in self.trainables else beta_nontrainable + new_value = tfutil.lerp(src_net.vars[name], var, cur_beta) + ops.append(var.assign(new_value)) + return tf.group(*ops) + + def run(self, + *in_arrays: Tuple[Union[np.ndarray, None], ...], + input_transform: dict = None, + output_transform: dict = None, + return_as_list: bool = False, + print_progress: bool = False, + minibatch_size: int = None, + num_gpus: int = 1, + assume_frozen: bool = False, + **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]: + """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s). + + Args: + input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network. + The dict must contain a 'func' field that points to a top-level function. The function is called with the input + TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. 
+ output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network. + The dict must contain a 'func' field that points to a top-level function. The function is called with the output + TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs. + return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs. + print_progress: Print progress to the console? Useful for very large input arrays. + minibatch_size: Maximum minibatch size to use, None = disable batching. + num_gpus: Number of GPUs to use. + assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls. + dynamic_kwargs: Additional keyword arguments to be passed into the network build function. + """ + assert len(in_arrays) == self.num_inputs + assert not all(arr is None for arr in in_arrays) + assert input_transform is None or util.is_top_level_function(input_transform["func"]) + assert output_transform is None or util.is_top_level_function(output_transform["func"]) + output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) + num_items = in_arrays[0].shape[0] + if minibatch_size is None: + minibatch_size = num_items + + # Construct unique hash key from all arguments that affect the TensorFlow graph. + key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) + def unwind_key(obj): + if isinstance(obj, dict): + return [(key, unwind_key(value)) for key, value in sorted(obj.items())] + if callable(obj): + return util.get_top_level_function_name(obj) + return obj + key = repr(unwind_key(key)) + + # Build graph. + if key not in self._run_cache: + with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): + with tf.device("/cpu:0"): + in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] + in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) + + out_split = [] + for gpu in range(num_gpus): + with tf.device("/gpu:%d" % gpu): + net_gpu = self.clone() if assume_frozen else self + in_gpu = in_split[gpu] + + if input_transform is not None: + in_kwargs = dict(input_transform) + in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) + in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) + + assert len(in_gpu) == self.num_inputs + out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) + + if output_transform is not None: + out_kwargs = dict(output_transform) + out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) + out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) + + assert len(out_gpu) == self.num_outputs + out_split.append(out_gpu) + + with tf.device("/cpu:0"): + out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] + self._run_cache[key] = in_expr, out_expr + + # Run minibatches. 
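+        # Feed the inputs through the cached graph one minibatch at a time, writing each
+        # result slice into preallocated output arrays. Typical call, illustrative and
+        # mirroring generate_figures.py below:
+        #   images = net.run(latents, None, minibatch_size=8,
+        #                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))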
+ in_expr, out_expr = self._run_cache[key] + out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr] + + for mb_begin in range(0, num_items, minibatch_size): + if print_progress: + print("\r%d / %d" % (mb_begin, num_items), end="") + + mb_end = min(mb_begin + minibatch_size, num_items) + mb_num = mb_end - mb_begin + mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] + mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) + + for dst, src in zip(out_arrays, mb_out): + dst[mb_begin: mb_end] = src + + # Done. + if print_progress: + print("\r%d / %d" % (num_items, num_items)) + + if not return_as_list: + out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) + return out_arrays + + def list_ops(self) -> List[TfExpression]: + include_prefix = self.scope + "/" + exclude_prefix = include_prefix + "_" + ops = tf.get_default_graph().get_operations() + ops = [op for op in ops if op.name.startswith(include_prefix)] + ops = [op for op in ops if not op.name.startswith(exclude_prefix)] + return ops + + def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: + """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to + individual layers of the network. Mainly intended to be used for reporting.""" + layers = [] + + def recurse(scope, parent_ops, parent_vars, level): + # Ignore specific patterns. + if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): + return + + # Filter ops and vars by scope. + global_prefix = scope + "/" + local_prefix = global_prefix[len(self.scope) + 1:] + cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] + cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] + if not cur_ops and not cur_vars: + return + + # Filter out all ops related to variables. + for var in [op for op in cur_ops if op.type.startswith("Variable")]: + var_prefix = var.name + "/" + cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] + + # Scope does not contain ops as immediate children => recurse deeper. + contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops) + if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1: + visited = set() + for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: + token = rel_name.split("/")[0] + if token not in visited: + recurse(global_prefix + token, cur_ops, cur_vars, level + 1) + visited.add(token) + return + + # Report layer. 
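+            # This scope has immediate ops of its own => report it as a single layer, using
+            # the last op's output (or the last variable) together with its trainable variables.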
+ layer_name = scope[len(self.scope) + 1:] + layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] + layer_trainables = [var for _name, var in cur_vars if var.trainable] + layers.append((layer_name, layer_output, layer_trainables)) + + recurse(self.scope, self.list_ops(), list(self.vars.items()), 0) + return layers + + def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: + """Print a summary table of the network structure.""" + rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] + rows += [["---"] * 4] + total_params = 0 + + for layer_name, layer_output, layer_trainables in self.list_layers(): + num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables) + weights = [var for var in layer_trainables if var.name.endswith("/weight:0")] + weights.sort(key=lambda x: len(x.name)) + if len(weights) == 0 and len(layer_trainables) == 1: + weights = layer_trainables + total_params += num_params + + if not hide_layers_with_no_params or num_params != 0: + num_params_str = str(num_params) if num_params > 0 else "-" + output_shape_str = str(layer_output.shape) + weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" + rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] + + rows += [["---"] * 4] + rows += [["Total", str(total_params), "", ""]] + + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) + print() + + def setup_weight_histograms(self, title: str = None) -> None: + """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" + if title is None: + title = self.name + + with tf.name_scope(None), tf.device(None), tf.control_dependencies(None): + for local_name, var in self.trainables.items(): + if "/" in local_name: + p = local_name.split("/") + name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) + else: + name = title + "_toplevel/" + local_name + + tf.summary.histogram(name, var) + +#---------------------------------------------------------------------------- +# Backwards-compatible emulation of legacy output transformation in Network.run(). 
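+# The helpers below translate the legacy out_mul / out_add / out_shrink / out_dtype
+# keyword arguments of Network.run() into an equivalent output_transform dict, printing
+# a one-time deprecation warning when they are used.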
+ +_print_legacy_warning = True + +def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): + global _print_legacy_warning + legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] + if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): + return output_transform, dynamic_kwargs + + if _print_legacy_warning: + _print_legacy_warning = False + print() + print("WARNING: Old-style output transformations in Network.run() are deprecated.") + print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") + print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") + print() + assert output_transform is None + + new_kwargs = dict(dynamic_kwargs) + new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} + new_transform["func"] = _legacy_output_transform_func + return new_transform, new_kwargs + +def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): + if out_mul != 1.0: + expr = [x * out_mul for x in expr] + + if out_add != 0.0: + expr = [x + out_add for x in expr] + + if out_shrink > 1: + ksize = [1, 1, out_shrink, out_shrink] + expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] + + if out_dtype is not None: + if tf.as_dtype(out_dtype).is_integer: + expr = [tf.round(x) for x in expr] + expr = [tf.saturate_cast(x, out_dtype) for x in expr] + return expr diff --git a/models/stylegan/stylegan_tf/dnnlib/tflib/optimizer.py b/models/stylegan/stylegan_tf/dnnlib/tflib/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed88cb236365234597f8734299fbb315c56cc73 --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/tflib/optimizer.py @@ -0,0 +1,214 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Helper wrapper for a Tensorflow optimizer.""" + +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import List, Union + +from . import autosummary +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +try: + # TensorFlow 1.13 + from tensorflow.python.ops import nccl_ops +except: + # Older TensorFlow versions + import tensorflow.contrib.nccl as nccl_ops + +class Optimizer: + """A Wrapper for tf.train.Optimizer. + + Automatically takes care of: + - Gradient averaging for multi-GPU training. + - Dynamic loss scaling and typecasts for FP16 training. + - Ignoring corrupted gradients that contain NaNs/Infs. + - Reporting statistics. + - Well-chosen default settings. + """ + + def __init__(self, + name: str = "Train", + tf_optimizer: str = "tf.train.AdamOptimizer", + learning_rate: TfExpressionEx = 0.001, + use_loss_scaling: bool = False, + loss_scaling_init: float = 64.0, + loss_scaling_inc: float = 0.0005, + loss_scaling_dec: float = 1.0, + **kwargs): + + # Init fields. 
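+        # Loss scaling is tracked per device as log2 of the scaling factor, so the additive
+        # loss_scaling_inc / loss_scaling_dec adjustments below correspond to multiplying or
+        # dividing the effective scale by a power of two.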
+ self.name = name + self.learning_rate = tf.convert_to_tensor(learning_rate) + self.id = self.name.replace("/", ".") + self.scope = tf.get_default_graph().unique_name(self.id) + self.optimizer_class = util.get_obj_by_name(tf_optimizer) + self.optimizer_kwargs = dict(kwargs) + self.use_loss_scaling = use_loss_scaling + self.loss_scaling_init = loss_scaling_init + self.loss_scaling_inc = loss_scaling_inc + self.loss_scaling_dec = loss_scaling_dec + self._grad_shapes = None # [shape, ...] + self._dev_opt = OrderedDict() # device => optimizer + self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] + self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) + self._updates_applied = False + + def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: + """Register the gradients of the given loss function with respect to the given variables. + Intended to be called once per GPU.""" + assert not self._updates_applied + + # Validate arguments. + if isinstance(trainable_vars, dict): + trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars + + assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 + assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) + + if self._grad_shapes is None: + self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] + + assert len(trainable_vars) == len(self._grad_shapes) + assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) + + dev = loss.device + + assert all(var.device == dev for var in trainable_vars) + + # Register device and compute gradients. + with tf.name_scope(self.id + "_grad"), tf.device(dev): + if dev not in self._dev_opt: + opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) + assert callable(self.optimizer_class) + self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) + self._dev_grads[dev] = [] + + loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) + grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage + grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros + self._dev_grads[dev].append(grads) + + def apply_updates(self) -> tf.Operation: + """Construct training op to update the registered variables based on their gradients.""" + tfutil.assert_tf_initialized() + assert not self._updates_applied + self._updates_applied = True + devices = list(self._dev_grads.keys()) + total_grads = sum(len(grads) for grads in self._dev_grads.values()) + assert len(devices) >= 1 and total_grads >= 1 + ops = [] + + with tfutil.absolute_name_scope(self.scope): + # Cast gradients to FP32 and calculate partial sum within each device. + dev_grads = OrderedDict() # device => [(grad, var), ...] + + for dev_idx, dev in enumerate(devices): + with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): + sums = [] + + for gv in zip(*self._dev_grads[dev]): + assert all(v is gv[0][1] for g, v in gv) + g = [tf.cast(g, tf.float32) for g, v in gv] + g = g[0] if len(g) == 1 else tf.add_n(g) + sums.append((g, gv[0][1])) + + dev_grads[dev] = sums + + # Sum gradients across devices. 
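+            # With more than one device, use NCCL all-reduce so every GPU ends up holding the
+            # same summed gradient for each variable before applying its local update.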
+ if len(devices) > 1: + with tf.name_scope("SumAcrossGPUs"), tf.device(None): + for var_idx, grad_shape in enumerate(self._grad_shapes): + g = [dev_grads[dev][var_idx][0] for dev in devices] + + if np.prod(grad_shape): # nccl does not support zero-sized tensors + g = nccl_ops.all_sum(g) + + for dev, gg in zip(devices, g): + dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) + + # Apply updates separately on each device. + for dev_idx, (dev, grads) in enumerate(dev_grads.items()): + with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): + # Scale gradients as needed. + if self.use_loss_scaling or total_grads > 1: + with tf.name_scope("Scale"): + coef = tf.constant(np.float32(1.0 / total_grads), name="coef") + coef = self.undo_loss_scaling(coef) + grads = [(g * coef, v) for g, v in grads] + + # Check for overflows. + with tf.name_scope("CheckOverflow"): + grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) + + # Update weights and adjust loss scaling. + with tf.name_scope("UpdateWeights"): + # pylint: disable=cell-var-from-loop + opt = self._dev_opt[dev] + ls_var = self.get_loss_scaling_var(dev) + + if not self.use_loss_scaling: + ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) + else: + ops.append(tf.cond(grad_ok, + lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), + lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) + + # Report statistics on the last device. + if dev == devices[-1]: + with tf.name_scope("Statistics"): + ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) + ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) + + if self.use_loss_scaling: + ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) + + # Initialize variables and group everything into a single op. 
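+            # The grouped op returned below both applies the (possibly scaled) gradients and
+            # maintains the per-device loss-scaling variables; it is intended to be run once
+            # per training iteration.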
+ self.reset_optimizer_state() + tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) + + return tf.group(*ops, name="TrainingOp") + + def reset_optimizer_state(self) -> None: + """Reset internal state of the underlying optimizer.""" + tfutil.assert_tf_initialized() + tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) + + def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: + """Get or create variable representing log2 of the current dynamic loss scaling factor.""" + if not self.use_loss_scaling: + return None + + if device not in self._dev_ls_var: + with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): + self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") + + return self._dev_ls_var[device] + + def apply_loss_scaling(self, value: TfExpression) -> TfExpression: + """Apply dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + + if not self.use_loss_scaling: + return value + + return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) + + def undo_loss_scaling(self, value: TfExpression) -> TfExpression: + """Undo the effect of dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + + if not self.use_loss_scaling: + return value + + return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type diff --git a/models/stylegan/stylegan_tf/dnnlib/tflib/tfutil.py b/models/stylegan/stylegan_tf/dnnlib/tflib/tfutil.py new file mode 100644 index 0000000000000000000000000000000000000000..a431a4d4d18a32c9cd44a14ce89f35e038dc312c --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/tflib/tfutil.py @@ -0,0 +1,240 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Miscellaneous helper utils for Tensorflow.""" + +import os +import numpy as np +import tensorflow as tf + +from typing import Any, Iterable, List, Union + +TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] +"""A type that represents a valid Tensorflow expression.""" + +TfExpressionEx = Union[TfExpression, int, float, np.ndarray] +"""A type that can be converted to a valid Tensorflow expression.""" + + +def run(*args, **kwargs) -> Any: + """Run the specified ops in the default session.""" + assert_tf_initialized() + return tf.get_default_session().run(*args, **kwargs) + + +def is_tf_expression(x: Any) -> bool: + """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" + return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) + + +def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: + """Convert a Tensorflow shape to a list of ints.""" + return [dim.value for dim in shape] + + +def flatten(x: TfExpressionEx) -> TfExpression: + """Shortcut function for flattening a tensor.""" + with tf.name_scope("Flatten"): + return tf.reshape(x, [-1]) + + +def log2(x: TfExpressionEx) -> TfExpression: + """Logarithm in base 2.""" + with tf.name_scope("Log2"): + return tf.log(x) * np.float32(1.0 / np.log(2.0)) + + +def exp2(x: TfExpressionEx) -> TfExpression: + """Exponent in base 2.""" + with tf.name_scope("Exp2"): + return tf.exp(x * np.float32(np.log(2.0))) + + +def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: + """Linear interpolation.""" + with tf.name_scope("Lerp"): + return a + (b - a) * t + + +def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: + """Linear interpolation with clip.""" + with tf.name_scope("LerpClip"): + return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) + + +def absolute_name_scope(scope: str) -> tf.name_scope: + """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" + return tf.name_scope(scope + "/") + + +def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: + """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" + return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) + + +def _sanitize_tf_config(config_dict: dict = None) -> dict: + # Defaults. + cfg = dict() + cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. + cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. + cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. + cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. + cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. + + # User overrides. + if config_dict is not None: + cfg.update(config_dict) + return cfg + + +def init_tf(config_dict: dict = None) -> None: + """Initialize TensorFlow session using good default settings.""" + # Skip if already initialized. + if tf.get_default_session() is not None: + return + + # Setup config dict and random seeds. 
+ cfg = _sanitize_tf_config(config_dict) + np_random_seed = cfg["rnd.np_random_seed"] + if np_random_seed is not None: + np.random.seed(np_random_seed) + tf_random_seed = cfg["rnd.tf_random_seed"] + if tf_random_seed == "auto": + tf_random_seed = np.random.randint(1 << 31) + if tf_random_seed is not None: + tf.set_random_seed(tf_random_seed) + + # Setup environment variables. + for key, value in list(cfg.items()): + fields = key.split(".") + if fields[0] == "env": + assert len(fields) == 2 + os.environ[fields[1]] = str(value) + + # Create default TensorFlow session. + create_session(cfg, force_as_default=True) + + +def assert_tf_initialized(): + """Check that TensorFlow session has been initialized.""" + if tf.get_default_session() is None: + raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") + + +def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: + """Create tf.Session based on config dict.""" + # Setup TensorFlow config proto. + cfg = _sanitize_tf_config(config_dict) + config_proto = tf.ConfigProto() + for key, value in cfg.items(): + fields = key.split(".") + if fields[0] not in ["rnd", "env"]: + obj = config_proto + for field in fields[:-1]: + obj = getattr(obj, field) + setattr(obj, fields[-1], value) + + # Create session. + session = tf.Session(config=config_proto) + if force_as_default: + # pylint: disable=protected-access + session._default_session = session.as_default() + session._default_session.enforce_nesting = False + session._default_session.__enter__() # pylint: disable=no-member + + return session + + +def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: + """Initialize all tf.Variables that have not already been initialized. + + Equivalent to the following, but more efficient and does not bloat the tf graph: + tf.variables_initializer(tf.report_uninitialized_variables()).run() + """ + assert_tf_initialized() + if target_vars is None: + target_vars = tf.global_variables() + + test_vars = [] + test_ops = [] + + with tf.control_dependencies(None): # ignore surrounding control_dependencies + for var in target_vars: + assert is_tf_expression(var) + + try: + tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) + except KeyError: + # Op does not exist => variable may be uninitialized. + test_vars.append(var) + + with absolute_name_scope(var.name.split(":")[0]): + test_ops.append(tf.is_variable_initialized(var)) + + init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] + run([var.initializer for var in init_vars]) + + +def set_vars(var_to_value_dict: dict) -> None: + """Set the values of given tf.Variables. 
+ + Equivalent to the following, but more efficient and does not bloat the tf graph: + tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] + """ + assert_tf_initialized() + ops = [] + feed_dict = {} + + for var, value in var_to_value_dict.items(): + assert is_tf_expression(var) + + try: + setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op + except KeyError: + with absolute_name_scope(var.name.split(":")[0]): + with tf.control_dependencies(None): # ignore surrounding control_dependencies + setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter + + ops.append(setter) + feed_dict[setter.op.inputs[1]] = value + + run(ops, feed_dict) + + +def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): + """Create tf.Variable with large initial value without bloating the tf graph.""" + assert_tf_initialized() + assert isinstance(initial_value, np.ndarray) + zeros = tf.zeros(initial_value.shape, initial_value.dtype) + var = tf.Variable(zeros, *args, **kwargs) + set_vars({var: initial_value}) + return var + + +def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): + """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. + Can be used as an input transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if nhwc_to_nchw: + images = tf.transpose(images, [0, 3, 1, 2]) + return (images - drange[0]) * ((drange[1] - drange[0]) / 255) + + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if shrink > 1: + ksize = [1, 1, shrink, shrink] + images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") + if nchw_to_nhwc: + images = tf.transpose(images, [0, 2, 3, 1]) + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + return tf.saturate_cast(images, tf.uint8) diff --git a/models/stylegan/stylegan_tf/dnnlib/util.py b/models/stylegan/stylegan_tf/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..133ef764c0707d9384a33f0350ba71b1e624072f --- /dev/null +++ b/models/stylegan/stylegan_tf/dnnlib/util.py @@ -0,0 +1,405 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: str) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": 
ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? 
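+    # If the module imports cleanly but the attribute lookup still fails, let the
+    # resulting AttributeError propagate instead of masking it as an ImportError.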
+ for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + return obj.__module__ + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. 
+ Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert is_url(url) + assert num_attempts >= 1 + + # Lookup from cache. + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache_dir is not None: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + return open(cache_files[0], "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive quota exceeded") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache_dir is not None: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + + # Return data as file object. + return io.BytesIO(url_data) diff --git a/models/stylegan/stylegan_tf/generate_figures.py b/models/stylegan/stylegan_tf/generate_figures.py new file mode 100644 index 0000000000000000000000000000000000000000..45b68b86146198c701a66fb8ba7a363d901d6951 --- /dev/null +++ b/models/stylegan/stylegan_tf/generate_figures.py @@ -0,0 +1,161 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. 
To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators.""" + +import os +import pickle +import numpy as np +import PIL.Image +import dnnlib +import dnnlib.tflib as tflib +import config + +#---------------------------------------------------------------------------- +# Helpers for loading and using pre-trained generators. + +url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl +url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl +url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl +url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl +url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl + +synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8) + +_Gs_cache = dict() + +def load_Gs(url): + if url not in _Gs_cache: + with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: + _G, _D, Gs = pickle.load(f) + _Gs_cache[url] = Gs + return _Gs_cache[url] + +#---------------------------------------------------------------------------- +# Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images. + +def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed): + print(png) + latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1]) + images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb] + + canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white') + image_iter = iter(list(images)) + for col, lod in enumerate(lods): + for row in range(rows * 2**lod): + image = PIL.Image.fromarray(next(image_iter), 'RGB') + image = image.crop((cx, cy, cx + cw, cy + ch)) + image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS) + canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod)) + canvas.save(png) + +#---------------------------------------------------------------------------- +# Figure 3: Style mixing. 
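# Illustrative aside: a minimal NumPy sketch of the layer-wise mixing that
# draw_style_mixing_figure below performs on the mapped dlatents. Shapes are the FFHQ
# 1024x1024 defaults (18 layers, 512 components); the arrays are random stand-ins.
import numpy as np
rng = np.random.RandomState(0)
num_src, num_layers, dim = 5, 18, 512
src_dlatents = rng.randn(num_src, num_layers, dim)   # one w per source seed, repeated over layers
dst_dlatent = rng.randn(num_layers, dim)             # w for a single destination seed
style_range = range(0, 4)                            # "coarse" layers copied from the sources
row_dlatents = np.stack([dst_dlatent] * num_src)     # [source, layer, component]
row_dlatents[:, style_range] = src_dlatents[:, style_range]
assert row_dlatents.shape == (num_src, num_layers, dim)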
+ +def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges): + print(png) + src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) + dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds) + src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] + dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component] + src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs) + dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs) + + canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white') + for col, src_image in enumerate(list(src_images)): + canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0)) + for row, dst_image in enumerate(list(dst_images)): + canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h)) + row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds)) + row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]] + row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs) + for col, image in enumerate(list(row_images)): + canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h)) + canvas.save(png) + +#---------------------------------------------------------------------------- +# Figure 4: Noise detail. + +def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds): + print(png) + canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white') + for row, seed in enumerate(seeds): + latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples) + images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs) + canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h)) + for i in range(4): + crop = PIL.Image.fromarray(images[i + 1], 'RGB') + crop = crop.crop((650, 180, 906, 436)) + crop = crop.resize((w//2, h//2), PIL.Image.NEAREST) + canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2)) + diff = np.std(np.mean(images, axis=3), axis=0) * 4 + diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8) + canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h)) + canvas.save(png) + +#---------------------------------------------------------------------------- +# Figure 5: Noise components. + +def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips): + print(png) + Gsc = Gs.clone() + noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')] + noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...] 
+ latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds) + all_images = [] + for noise_range in noise_ranges: + tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)}) + range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs) + range_images[flips, :, :] = range_images[flips, :, ::-1] + all_images.append(list(range_images)) + + canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white') + for col, col_images in enumerate(zip(*all_images)): + canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0)) + canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0)) + canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h)) + canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h)) + canvas.save(png) + +#---------------------------------------------------------------------------- +# Figure 8: Truncation trick. + +def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis): + print(png) + latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds) + dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component] + dlatent_avg = Gs.get_var('dlatent_avg') # [component] + + canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white') + for row, dlatent in enumerate(list(dlatents)): + row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg + row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs) + for col, image in enumerate(list(row_images)): + canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h)) + canvas.save(png) + +#---------------------------------------------------------------------------- +# Main program. 
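# Illustrative aside: the truncation trick used by draw_truncation_trick_figure above,
# re-expressed as a standalone NumPy sketch with made-up shapes (18 layers, 512 components)
# and a zero stand-in for Gs.get_var('dlatent_avg').
import numpy as np
rng = np.random.RandomState(0)
num_layers, dim = 18, 512
dlatent = rng.randn(num_layers, dim)                 # w for one seed
dlatent_avg = np.zeros(dim)                          # stand-in for the tracked average w
psis = np.array([1.0, 0.7, 0.5, 0.0, -0.5, -1.0])
rows = (dlatent[np.newaxis] - dlatent_avg) * psis.reshape(-1, 1, 1) + dlatent_avg
assert rows.shape == (len(psis), num_layers, dim)    # one interpolated w per psi; psi=1 recovers w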
+ +def main(): + tflib.init_tf() + os.makedirs(config.result_dir, exist_ok=True) + draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5) + draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)]) + draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012]) + draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1]) + draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1]) + draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0) + draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2) + draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/metrics/__init__.py b/models/stylegan/stylegan_tf/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db8124b132f91216c0ded226f20ea3a046734728 --- /dev/null +++ b/models/stylegan/stylegan_tf/metrics/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +# empty diff --git a/models/stylegan/stylegan_tf/metrics/frechet_inception_distance.py b/models/stylegan/stylegan_tf/metrics/frechet_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..41f71fe4bfb85218cc283b3f7bc3a34fea5f790d --- /dev/null +++ b/models/stylegan/stylegan_tf/metrics/frechet_inception_distance.py @@ -0,0 +1,72 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
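# Illustrative aside: the FID class defined below in this file reduces to the closed-form
# Frechet distance between two Gaussians fitted to Inception activations. A standalone
# NumPy/SciPy sketch of that final step, using random stand-in activations of width 8:
import numpy as np
import scipy.linalg
rng = np.random.RandomState(0)
act_real = rng.randn(1000, 8)                        # stand-in for Inception features of reals
act_fake = rng.randn(1000, 8) + 0.1                  # stand-in for Inception features of fakes
mu_r, sigma_r = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
mu_f, sigma_f = act_fake.mean(axis=0), np.cov(act_fake, rowvar=False)
s, _ = scipy.linalg.sqrtm(sigma_f.dot(sigma_r), disp=False)
fid = np.square(mu_f - mu_r).sum() + np.trace(sigma_f + sigma_r - 2 * np.real(s))
print(fid)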
+ +"""Frechet Inception Distance (FID).""" + +import os +import numpy as np +import scipy +import tensorflow as tf +import dnnlib.tflib as tflib + +from metrics import metric_base +from training import misc + +#---------------------------------------------------------------------------- + +class FID(metric_base.MetricBase): + def __init__(self, num_images, minibatch_per_gpu, **kwargs): + super().__init__(**kwargs) + self.num_images = num_images + self.minibatch_per_gpu = minibatch_per_gpu + + def _evaluate(self, Gs, num_gpus): + minibatch_size = num_gpus * self.minibatch_per_gpu + inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl + activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32) + + # Calculate statistics for reals. + cache_file = self._get_cache_file_for_reals(num_images=self.num_images) + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + if os.path.isfile(cache_file): + mu_real, sigma_real = misc.load_pkl(cache_file) + else: + for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)): + begin = idx * minibatch_size + end = min(begin + minibatch_size, self.num_images) + activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True) + if end == self.num_images: + break + mu_real = np.mean(activations, axis=0) + sigma_real = np.cov(activations, rowvar=False) + misc.save_pkl((mu_real, sigma_real), cache_file) + + # Construct TensorFlow graph. + result_expr = [] + for gpu_idx in range(num_gpus): + with tf.device('/gpu:%d' % gpu_idx): + Gs_clone = Gs.clone() + inception_clone = inception.clone() + latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) + images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) + images = tflib.convert_images_to_uint8(images) + result_expr.append(inception_clone.get_output_for(images)) + + # Calculate statistics for fakes. + for begin in range(0, self.num_images, minibatch_size): + end = min(begin + minibatch_size, self.num_images) + activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin] + mu_fake = np.mean(activations, axis=0) + sigma_fake = np.cov(activations, rowvar=False) + + # Calculate FID. + m = np.square(mu_fake - mu_real).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member + dist = m + np.trace(sigma_fake + sigma_real - 2*s) + self._report_result(np.real(dist)) + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/metrics/linear_separability.py b/models/stylegan/stylegan_tf/metrics/linear_separability.py new file mode 100644 index 0000000000000000000000000000000000000000..e50be5a0fea00eba7af2d05cccf74bacedbea1c3 --- /dev/null +++ b/models/stylegan/stylegan_tf/metrics/linear_separability.py @@ -0,0 +1,177 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+ +"""Linear Separability (LS).""" + +from collections import defaultdict +import numpy as np +import sklearn.svm +import tensorflow as tf +import dnnlib.tflib as tflib + +from metrics import metric_base +from training import misc + +#---------------------------------------------------------------------------- + +classifier_urls = [ + 'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl + 'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl + 'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl + 'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl + 'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl + 'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl + 'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl + 'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl + 'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl + 'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl + 'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl + 'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl + 'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl + 'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl + 'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl + 'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl + 'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl + 'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl + 'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl + 'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl + 'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl + 'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl + 'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl + 'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl + 'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl + 'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl + 'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl + 'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl + 'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl + 
'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl + 'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl + 'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl + 'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl + 'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl + 'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl + 'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl + 'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl + 'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl + 'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl + 'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl +] + +#---------------------------------------------------------------------------- + +def prob_normalize(p): + p = np.asarray(p).astype(np.float32) + assert len(p.shape) == 2 + return p / np.sum(p) + +def mutual_information(p): + p = prob_normalize(p) + px = np.sum(p, axis=1) + py = np.sum(p, axis=0) + result = 0.0 + for x in range(p.shape[0]): + p_x = px[x] + for y in range(p.shape[1]): + p_xy = p[x][y] + p_y = py[y] + if p_xy > 0.0: + result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output + return result + +def entropy(p): + p = prob_normalize(p) + result = 0.0 + for x in range(p.shape[0]): + for y in range(p.shape[1]): + p_xy = p[x][y] + if p_xy > 0.0: + result -= p_xy * np.log2(p_xy) + return result + +def conditional_entropy(p): + # H(Y|X) where X corresponds to axis 0, Y to axis 1 + # i.e., How many bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0? + p = prob_normalize(p) + y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y) + return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up. + +#---------------------------------------------------------------------------- + +class LS(metric_base.MetricBase): + def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs): + assert num_keep <= num_samples + super().__init__(**kwargs) + self.num_samples = num_samples + self.num_keep = num_keep + self.attrib_indices = attrib_indices + self.minibatch_per_gpu = minibatch_per_gpu + + def _evaluate(self, Gs, num_gpus): + minibatch_size = num_gpus * self.minibatch_per_gpu + + # Construct TensorFlow graph for each GPU. + result_expr = [] + for gpu_idx in range(num_gpus): + with tf.device('/gpu:%d' % gpu_idx): + Gs_clone = Gs.clone() + + # Generate images. + latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) + dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True) + images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True) + + # Downsample to 256x256. The attribute classifiers were built for 256x256. 
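# Aside: the reshape + reduce_mean below is box-filter downsampling (average pooling). The
# same trick in NumPy, on a made-up 4x4 single-channel image with factor 2:
#   import numpy as np
#   x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)        # NCHW
#   pooled = x.reshape(1, 1, 2, 2, 2, 2).mean(axis=(3, 5))         # -> [[2.5, 4.5], [10.5, 12.5]]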
+ if images.shape[2] > 256: + factor = images.shape[2] // 256 + images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) + images = tf.reduce_mean(images, axis=[3, 5]) + + # Run classifier for each attribute. + result_dict = dict(latents=latents, dlatents=dlatents[:,-1]) + for attrib_idx in self.attrib_indices: + classifier = misc.load_pkl(classifier_urls[attrib_idx]) + logits = classifier.get_output_for(images, None) + predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1)) + result_dict[attrib_idx] = predictions + result_expr.append(result_dict) + + # Sampling loop. + results = [] + for _ in range(0, self.num_samples, minibatch_size): + results += tflib.run(result_expr) + results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()} + + # Calculate conditional entropy for each attribute. + conditional_entropies = defaultdict(list) + for attrib_idx in self.attrib_indices: + # Prune the least confident samples. + pruned_indices = list(range(self.num_samples)) + pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i])) + pruned_indices = pruned_indices[:self.num_keep] + + # Fit SVM to the remaining samples. + svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1) + for space in ['latents', 'dlatents']: + svm_inputs = results[space][pruned_indices] + try: + svm = sklearn.svm.LinearSVC() + svm.fit(svm_inputs, svm_targets) + svm.score(svm_inputs, svm_targets) + svm_outputs = svm.predict(svm_inputs) + except: + svm_outputs = svm_targets # assume perfect prediction + + # Calculate conditional entropy. + p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)] + conditional_entropies[space].append(conditional_entropy(p)) + + # Calculate separability scores. + scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()} + self._report_result(scores['latents'], suffix='_z') + self._report_result(scores['dlatents'], suffix='_w') + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/metrics/metric_base.py b/models/stylegan/stylegan_tf/metrics/metric_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0db82adecb60260393eaf82bd991575d79085787 --- /dev/null +++ b/models/stylegan/stylegan_tf/metrics/metric_base.py @@ -0,0 +1,142 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Common definitions for GAN metrics.""" + +import os +import time +import hashlib +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib + +import config +from training import misc +from training import dataset + +#---------------------------------------------------------------------------- +# Standard metrics. 
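# Illustrative aside: each preset below is just a dict of constructor kwargs plus a dotted
# 'func_name'; MetricGroup turns it into a live metric object via dnnlib.util.call_func_by_name.
# A simplified stand-in for that resolution step (the real helper in dnnlib/util.py is more
# tolerant about where the module/attribute split falls):
import importlib
def call_func_by_name_sketch(*args, func_name=None, **kwargs):
    module_name, attr_name = func_name.rsplit(".", 1)
    func = getattr(importlib.import_module(module_name), attr_name)
    return func(*args, **kwargs)
# e.g. call_func_by_name_sketch(**fid50k) would construct
# metrics.frechet_inception_distance.FID(name='fid50k', num_images=50000, minibatch_per_gpu=8).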
+ +fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8) +ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16) +ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16) +ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16) +ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16) +ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4) +dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging + +#---------------------------------------------------------------------------- +# Base class for metrics. + +class MetricBase: + def __init__(self, name): + self.name = name + self._network_pkl = None + self._dataset_args = None + self._mirror_augment = None + self._results = [] + self._eval_time = None + + def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True): + self._network_pkl = network_pkl + self._dataset_args = dataset_args + self._mirror_augment = mirror_augment + self._results = [] + + if (dataset_args is None or mirror_augment is None) and run_dir is not None: + run_config = misc.parse_config_for_previous_run(run_dir) + self._dataset_args = dict(run_config['dataset']) + self._dataset_args['shuffle_mb'] = 0 + self._mirror_augment = run_config['train'].get('mirror_augment', False) + + time_begin = time.time() + with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager + _G, _D, Gs = misc.load_pkl(self._network_pkl) + self._evaluate(Gs, num_gpus=num_gpus) + self._eval_time = time.time() - time_begin + + if log_results: + result_str = self.get_result_str() + if run_dir is not None: + log = os.path.join(run_dir, 'metric-%s.txt' % self.name) + with dnnlib.util.Logger(log, 'a'): + print(result_str) + else: + print(result_str) + + def get_result_str(self): + network_name = os.path.splitext(os.path.basename(self._network_pkl))[0] + if len(network_name) > 29: + network_name = '...' 
+ network_name[-26:] + result_str = '%-30s' % network_name + result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time) + for res in self._results: + result_str += ' ' + self.name + res.suffix + ' ' + result_str += res.fmt % res.value + return result_str + + def update_autosummaries(self): + for res in self._results: + tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value) + + def _evaluate(self, Gs, num_gpus): + raise NotImplementedError # to be overridden by subclasses + + def _report_result(self, value, suffix='', fmt='%-10.4f'): + self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)] + + def _get_cache_file_for_reals(self, extension='pkl', **kwargs): + all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment) + all_args.update(self._dataset_args) + all_args.update(kwargs) + md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')) + dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1] + return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension)) + + def _iterate_reals(self, minibatch_size): + dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args) + while True: + images, _labels = dataset_obj.get_minibatch_np(minibatch_size) + if self._mirror_augment: + images = misc.apply_mirror_augment(images) + yield images + + def _iterate_fakes(self, Gs, minibatch_size, num_gpus): + while True: + latents = np.random.randn(minibatch_size, *Gs.input_shape[1:]) + fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True) + yield images + +#---------------------------------------------------------------------------- +# Group of multiple metrics. + +class MetricGroup: + def __init__(self, metric_kwarg_list): + self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list] + + def run(self, *args, **kwargs): + for metric in self.metrics: + metric.run(*args, **kwargs) + + def get_result_str(self): + return ' '.join(metric.get_result_str() for metric in self.metrics) + + def update_autosummaries(self): + for metric in self.metrics: + metric.update_autosummaries() + +#---------------------------------------------------------------------------- +# Dummy metric for debugging purposes. + +class DummyMetric(MetricBase): + def _evaluate(self, Gs, num_gpus): + _ = Gs, num_gpus + self._report_result(0.0) + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/metrics/perceptual_path_length.py b/models/stylegan/stylegan_tf/metrics/perceptual_path_length.py new file mode 100644 index 0000000000000000000000000000000000000000..17271cfdf1545a26ab71d309ce2180532f513bd6 --- /dev/null +++ b/models/stylegan/stylegan_tf/metrics/perceptual_path_length.py @@ -0,0 +1,108 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
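# Illustrative aside: the PPL metric defined below interpolates Z-space latents with slerp.
# The same formula in plain NumPy, as a quick sanity check on random vectors:
import numpy as np
def _normalize(v):
    return v / np.sqrt(np.sum(v * v, axis=-1, keepdims=True))
def _slerp(a, b, t):
    a, b = _normalize(a), _normalize(b)
    d = np.sum(a * b, axis=-1, keepdims=True)
    p = t * np.arccos(d)
    c = _normalize(b - d * a)
    return _normalize(a * np.cos(p) + c * np.sin(p))
rng = np.random.RandomState(0)
a, b = rng.randn(4, 16), rng.randn(4, 16)
mid = _slerp(a, b, 0.5)
assert np.allclose(np.linalg.norm(mid, axis=-1), 1.0)   # interpolants stay on the unit sphere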
+ +"""Perceptual Path Length (PPL).""" + +import numpy as np +import tensorflow as tf +import dnnlib.tflib as tflib + +from metrics import metric_base +from training import misc + +#---------------------------------------------------------------------------- + +# Normalize batch of vectors. +def normalize(v): + return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)) + +# Spherical interpolation of a batch of vectors. +def slerp(a, b, t): + a = normalize(a) + b = normalize(b) + d = tf.reduce_sum(a * b, axis=-1, keepdims=True) + p = t * tf.math.acos(d) + c = normalize(b - d * a) + d = a * tf.math.cos(p) + c * tf.math.sin(p) + return normalize(d) + +#---------------------------------------------------------------------------- + +class PPL(metric_base.MetricBase): + def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs): + assert space in ['z', 'w'] + assert sampling in ['full', 'end'] + super().__init__(**kwargs) + self.num_samples = num_samples + self.epsilon = epsilon + self.space = space + self.sampling = sampling + self.minibatch_per_gpu = minibatch_per_gpu + + def _evaluate(self, Gs, num_gpus): + minibatch_size = num_gpus * self.minibatch_per_gpu + + # Construct TensorFlow graph. + distance_expr = [] + for gpu_idx in range(num_gpus): + with tf.device('/gpu:%d' % gpu_idx): + Gs_clone = Gs.clone() + noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')] + + # Generate random latents and interpolation t-values. + lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:]) + lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0) + + # Interpolate in W or Z. + if self.space == 'w': + dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True) + dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2] + dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis]) + dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon) + dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape) + else: # space == 'z' + lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2] + lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis]) + lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon) + lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape) + dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True) + + # Synthesize images. + with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch + images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False) + + # Crop only the face region. + c = int(images.shape[2] // 8) + images = images[:, :, c*3 : c*7, c*2 : c*6] + + # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. + if images.shape[2] > 256: + factor = images.shape[2] // 256 + images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) + images = tf.reduce_mean(images, axis=[3,5]) + + # Scale dynamic range from [-1,1] to [0,255] for VGG. + images = (images + 1) * (255 / 2) + + # Evaluate perceptual distance. 
+ img_e0, img_e1 = images[0::2], images[1::2] + distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl + distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2)) + + # Sampling loop. + all_distances = [] + for _ in range(0, self.num_samples, minibatch_size): + all_distances += tflib.run(distance_expr) + all_distances = np.concatenate(all_distances, axis=0) + + # Reject outliers. + lo = np.percentile(all_distances, 1, interpolation='lower') + hi = np.percentile(all_distances, 99, interpolation='higher') + filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances) + self._report_result(np.mean(filtered_distances)) + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/pretrained_example.py b/models/stylegan/stylegan_tf/pretrained_example.py new file mode 100644 index 0000000000000000000000000000000000000000..63baef08bfa4bf34f52a0cf63e10a0b6783ac316 --- /dev/null +++ b/models/stylegan/stylegan_tf/pretrained_example.py @@ -0,0 +1,47 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Minimal script for generating an image using pre-trained StyleGAN generator.""" + +import os +import pickle +import numpy as np +import PIL.Image +import dnnlib +import dnnlib.tflib as tflib +import config + +def main(): + # Initialize TensorFlow. + tflib.init_tf() + + # Load pre-trained network. + url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl + with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: + _G, _D, Gs = pickle.load(f) + # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. + # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. + # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot. + + # Print network details. + Gs.print_layers() + + # Pick latent vector. + rnd = np.random.RandomState(5) + latents = rnd.randn(1, Gs.input_shape[1]) + + # Generate image. + fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) + + # Save image. + os.makedirs(config.result_dir, exist_ok=True) + png_filename = os.path.join(config.result_dir, 'example.png') + PIL.Image.fromarray(images[0], 'RGB').save(png_filename) + +if __name__ == "__main__": + main() diff --git a/models/stylegan/stylegan_tf/run_metrics.py b/models/stylegan/stylegan_tf/run_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..5d1597bbd4e16a2535309ea74c3559cae2a5fa53 --- /dev/null +++ b/models/stylegan/stylegan_tf/run_metrics.py @@ -0,0 +1,105 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. 
To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Main entry point for training StyleGAN and ProGAN networks.""" + +import dnnlib +from dnnlib import EasyDict +import dnnlib.tflib as tflib + +import config +from metrics import metric_base +from training import misc + +#---------------------------------------------------------------------------- + +def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment): + ctx = dnnlib.RunContext(submit_config) + tflib.init_tf() + print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl)) + metric = dnnlib.util.call_func_by_name(**metric_args) + print() + metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus) + print() + ctx.close() + +#---------------------------------------------------------------------------- + +def run_snapshot(submit_config, metric_args, run_id, snapshot): + ctx = dnnlib.RunContext(submit_config) + tflib.init_tf() + print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot)) + run_dir = misc.locate_run_dir(run_id) + network_pkl = misc.locate_network_pkl(run_dir, snapshot) + metric = dnnlib.util.call_func_by_name(**metric_args) + print() + metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus) + print() + ctx.close() + +#---------------------------------------------------------------------------- + +def run_all_snapshots(submit_config, metric_args, run_id): + ctx = dnnlib.RunContext(submit_config) + tflib.init_tf() + print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id)) + run_dir = misc.locate_run_dir(run_id) + network_pkls = misc.list_network_pkls(run_dir) + metric = dnnlib.util.call_func_by_name(**metric_args) + print() + for idx, network_pkl in enumerate(network_pkls): + ctx.update('', idx, len(network_pkls)) + metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus) + print() + ctx.close() + +#---------------------------------------------------------------------------- + +def main(): + submit_config = dnnlib.SubmitConfig() + + # Which metrics to evaluate? + metrics = [] + metrics += [metric_base.fid50k] + #metrics += [metric_base.ppl_zfull] + #metrics += [metric_base.ppl_wfull] + #metrics += [metric_base.ppl_zend] + #metrics += [metric_base.ppl_wend] + #metrics += [metric_base.ls] + #metrics += [metric_base.dummy] + + # Which networks to evaluate them on? + tasks = [] + tasks += [EasyDict(run_func_name='run_metrics.run_pickle', network_pkl='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', dataset_args=EasyDict(tfrecord_dir='ffhq', shuffle_mb=0), mirror_augment=True)] # karras2019stylegan-ffhq-1024x1024.pkl + #tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)] + #tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)] + + # How many GPUs to use? + submit_config.num_gpus = 1 + #submit_config.num_gpus = 2 + #submit_config.num_gpus = 4 + #submit_config.num_gpus = 8 + + # Execute. 
+ submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) + submit_config.run_dir_ignore += config.run_dir_ignore + for task in tasks: + for metric in metrics: + submit_config.run_desc = '%s-%s' % (task.run_func_name, metric.name) + if task.run_func_name.endswith('run_snapshot'): + submit_config.run_desc += '-%s-%s' % (task.run_id, task.snapshot) + if task.run_func_name.endswith('run_all_snapshots'): + submit_config.run_desc += '-%s' % task.run_id + submit_config.run_desc += '-%dgpu' % submit_config.num_gpus + dnnlib.submit_run(submit_config, metric_args=metric, **task) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/train.py b/models/stylegan/stylegan_tf/train.py new file mode 100644 index 0000000000000000000000000000000000000000..29df3c226b87816ceec25752293df08a70d63189 --- /dev/null +++ b/models/stylegan/stylegan_tf/train.py @@ -0,0 +1,192 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Main entry point for training StyleGAN and ProGAN networks.""" + +import copy +import dnnlib +from dnnlib import EasyDict + +import config +from metrics import metric_base + +#---------------------------------------------------------------------------- +# Official training configs for StyleGAN, targeted mainly for FFHQ. + +if 1: + desc = 'sgan' # Description string included in result subdir name. + train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. + G = EasyDict(func_name='training.networks_stylegan.G_style') # Options for generator network. + D = EasyDict(func_name='training.networks_stylegan.D_basic') # Options for discriminator network. + G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. + D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. + G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating') # Options for generator loss. + D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss. + dataset = EasyDict() # Options for load_dataset(). + sched = EasyDict() # Options for TrainingSchedule. + grid = EasyDict(size='4k', layout='random') # Options for setup_snapshot_image_grid(). + metrics = [metric_base.fid50k] # Options for MetricGroup. + submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). + tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). + + # Dataset. 
+ desc += '-ffhq'; dataset = EasyDict(tfrecord_dir='ffhq'); train.mirror_augment = True + #desc += '-ffhq512'; dataset = EasyDict(tfrecord_dir='ffhq', resolution=512); train.mirror_augment = True + #desc += '-ffhq256'; dataset = EasyDict(tfrecord_dir='ffhq', resolution=256); train.mirror_augment = True + #desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True + #desc += '-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-full'); train.mirror_augment = False + #desc += '-car'; dataset = EasyDict(tfrecord_dir='lsun-car-512x384'); train.mirror_augment = False + #desc += '-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-full'); train.mirror_augment = False + + # Number of GPUs. + #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4} + #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8} + #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16} + desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32} + + # Default options. + train.total_kimg = 25000 + sched.lod_initial_resolution = 8 + sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} + sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) + + # WGAN-GP loss for CelebA-HQ. + #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) + + # Table 1. + #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False + #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False + #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False + #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0 + #desc += '-mixing-regularization' # default + + # Table 2. + #desc += '-mix0'; G.style_mixing_prob = 0.0 + #desc += '-mix50'; G.style_mixing_prob = 0.5 + #desc += '-mix90'; G.style_mixing_prob = 0.9 # default + #desc += '-mix100'; G.style_mixing_prob = 1.0 + + # Table 4. + #desc += '-traditional-0'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False + #desc += '-traditional-8'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 8; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False + #desc += '-stylebased-0'; G.mapping_layers = 0 + #desc += '-stylebased-1'; G.mapping_layers = 1 + #desc += '-stylebased-2'; G.mapping_layers = 2 + #desc += '-stylebased-8'; G.mapping_layers = 8 # default + +#---------------------------------------------------------------------------- +# Official training configs for Progressive GAN, targeted mainly for CelebA-HQ. + +if 0: + desc = 'pgan' # Description string included in result subdir name. 
+ train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop. + G = EasyDict(func_name='training.networks_progan.G_paper') # Options for generator network. + D = EasyDict(func_name='training.networks_progan.D_paper') # Options for discriminator network. + G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer. + D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer. + G_loss = EasyDict(func_name='training.loss.G_wgan') # Options for generator loss. + D_loss = EasyDict(func_name='training.loss.D_wgan_gp') # Options for discriminator loss. + dataset = EasyDict() # Options for load_dataset(). + sched = EasyDict() # Options for TrainingSchedule. + grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid(). + metrics = [metric_base.fid50k] # Options for MetricGroup. + submit_config = dnnlib.SubmitConfig() # Options for dnnlib.submit_run(). + tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf(). + + # Dataset (choose one). + desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True + #desc += '-celeba'; dataset = EasyDict(tfrecord_dir='celeba'); train.mirror_augment = True + #desc += '-cifar10'; dataset = EasyDict(tfrecord_dir='cifar10') + #desc += '-cifar100'; dataset = EasyDict(tfrecord_dir='cifar100') + #desc += '-svhn'; dataset = EasyDict(tfrecord_dir='svhn') + #desc += '-mnist'; dataset = EasyDict(tfrecord_dir='mnist') + #desc += '-mnistrgb'; dataset = EasyDict(tfrecord_dir='mnistrgb') + #desc += '-syn1024rgb'; dataset = EasyDict(class_name='training.dataset.SyntheticDataset', resolution=1024, num_channels=3) + #desc += '-lsun-airplane'; dataset = EasyDict(tfrecord_dir='lsun-airplane-100k'); train.mirror_augment = True + #desc += '-lsun-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-100k'); train.mirror_augment = True + #desc += '-lsun-bicycle'; dataset = EasyDict(tfrecord_dir='lsun-bicycle-100k'); train.mirror_augment = True + #desc += '-lsun-bird'; dataset = EasyDict(tfrecord_dir='lsun-bird-100k'); train.mirror_augment = True + #desc += '-lsun-boat'; dataset = EasyDict(tfrecord_dir='lsun-boat-100k'); train.mirror_augment = True + #desc += '-lsun-bottle'; dataset = EasyDict(tfrecord_dir='lsun-bottle-100k'); train.mirror_augment = True + #desc += '-lsun-bridge'; dataset = EasyDict(tfrecord_dir='lsun-bridge-100k'); train.mirror_augment = True + #desc += '-lsun-bus'; dataset = EasyDict(tfrecord_dir='lsun-bus-100k'); train.mirror_augment = True + #desc += '-lsun-car'; dataset = EasyDict(tfrecord_dir='lsun-car-100k'); train.mirror_augment = True + #desc += '-lsun-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-100k'); train.mirror_augment = True + #desc += '-lsun-chair'; dataset = EasyDict(tfrecord_dir='lsun-chair-100k'); train.mirror_augment = True + #desc += '-lsun-churchoutdoor'; dataset = EasyDict(tfrecord_dir='lsun-churchoutdoor-100k'); train.mirror_augment = True + #desc += '-lsun-classroom'; dataset = EasyDict(tfrecord_dir='lsun-classroom-100k'); train.mirror_augment = True + #desc += '-lsun-conferenceroom'; dataset = EasyDict(tfrecord_dir='lsun-conferenceroom-100k'); train.mirror_augment = True + #desc += '-lsun-cow'; dataset = EasyDict(tfrecord_dir='lsun-cow-100k'); train.mirror_augment = True + #desc += '-lsun-diningroom'; dataset = EasyDict(tfrecord_dir='lsun-diningroom-100k'); train.mirror_augment = True + #desc += '-lsun-diningtable'; dataset = 
EasyDict(tfrecord_dir='lsun-diningtable-100k'); train.mirror_augment = True + #desc += '-lsun-dog'; dataset = EasyDict(tfrecord_dir='lsun-dog-100k'); train.mirror_augment = True + #desc += '-lsun-horse'; dataset = EasyDict(tfrecord_dir='lsun-horse-100k'); train.mirror_augment = True + #desc += '-lsun-kitchen'; dataset = EasyDict(tfrecord_dir='lsun-kitchen-100k'); train.mirror_augment = True + #desc += '-lsun-livingroom'; dataset = EasyDict(tfrecord_dir='lsun-livingroom-100k'); train.mirror_augment = True + #desc += '-lsun-motorbike'; dataset = EasyDict(tfrecord_dir='lsun-motorbike-100k'); train.mirror_augment = True + #desc += '-lsun-person'; dataset = EasyDict(tfrecord_dir='lsun-person-100k'); train.mirror_augment = True + #desc += '-lsun-pottedplant'; dataset = EasyDict(tfrecord_dir='lsun-pottedplant-100k'); train.mirror_augment = True + #desc += '-lsun-restaurant'; dataset = EasyDict(tfrecord_dir='lsun-restaurant-100k'); train.mirror_augment = True + #desc += '-lsun-sheep'; dataset = EasyDict(tfrecord_dir='lsun-sheep-100k'); train.mirror_augment = True + #desc += '-lsun-sofa'; dataset = EasyDict(tfrecord_dir='lsun-sofa-100k'); train.mirror_augment = True + #desc += '-lsun-tower'; dataset = EasyDict(tfrecord_dir='lsun-tower-100k'); train.mirror_augment = True + #desc += '-lsun-train'; dataset = EasyDict(tfrecord_dir='lsun-train-100k'); train.mirror_augment = True + #desc += '-lsun-tvmonitor'; dataset = EasyDict(tfrecord_dir='lsun-tvmonitor-100k'); train.mirror_augment = True + + # Conditioning & snapshot options. + #desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label + #desc += '-cond1'; dataset.max_label_size = 1 # conditioned on first component of the label + #desc += '-g4k'; grid.size = '4k' + #desc += '-grpc'; grid.layout = 'row_per_class' + + # Config presets (choose one). + #desc += '-preset-v1-1gpu'; submit_config.num_gpus = 1; D.mbstd_group_size = 16; sched.minibatch_base = 16; sched.minibatch_dict = {256: 14, 512: 6, 1024: 3}; sched.lod_training_kimg = 800; sched.lod_transition_kimg = 800; train.total_kimg = 19000 + desc += '-preset-v2-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 + #desc += '-preset-v2-2gpus'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}; sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 + #desc += '-preset-v2-4gpus'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}; sched.G_lrate_dict = {256: 0.0015, 512: 0.002, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 + #desc += '-preset-v2-8gpus'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}; sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000 + + # Numerical precision (choose one). 
+ desc += '-fp32'; sched.max_minibatch_per_gpu = {256: 16, 512: 8, 1024: 4} + #desc += '-fp16'; G.dtype = 'float16'; D.dtype = 'float16'; G.pixelnorm_epsilon=1e-4; G_opt.use_loss_scaling = True; D_opt.use_loss_scaling = True; sched.max_minibatch_per_gpu = {512: 16, 1024: 8} + + # Disable individual features. + #desc += '-nogrowing'; sched.lod_initial_resolution = 1024; sched.lod_training_kimg = 0; sched.lod_transition_kimg = 0; train.total_kimg = 10000 + #desc += '-nopixelnorm'; G.use_pixelnorm = False + #desc += '-nowscale'; G.use_wscale = False; D.use_wscale = False + #desc += '-noleakyrelu'; G.use_leakyrelu = False + #desc += '-nosmoothing'; train.G_smoothing_kimg = 0.0 + #desc += '-norepeat'; train.minibatch_repeats = 1 + #desc += '-noreset'; train.reset_opt_for_new_lod = False + + # Special modes. + #desc += '-BENCHMARK'; sched.lod_initial_resolution = 4; sched.lod_training_kimg = 3; sched.lod_transition_kimg = 3; train.total_kimg = (8*2+1)*3; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000 + #desc += '-BENCHMARK0'; sched.lod_initial_resolution = 1024; train.total_kimg = 10; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000 + #desc += '-VERBOSE'; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1; train.network_snapshot_ticks = 100 + #desc += '-GRAPH'; train.save_tf_graph = True + #desc += '-HIST'; train.save_weight_histograms = True + +#---------------------------------------------------------------------------- +# Main entry point for training. +# Calls the function indicated by 'train' using the selected options. + +def main(): + kwargs = EasyDict(train) + kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss) + kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config) + kwargs.submit_config = copy.deepcopy(submit_config) + kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir) + kwargs.submit_config.run_dir_ignore += config.run_dir_ignore + kwargs.submit_config.run_desc = desc + dnnlib.submit_run(**kwargs) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/training/__init__.py b/models/stylegan/stylegan_tf/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db8124b132f91216c0ded226f20ea3a046734728 --- /dev/null +++ b/models/stylegan/stylegan_tf/training/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +# empty diff --git a/models/stylegan/stylegan_tf/training/dataset.py b/models/stylegan/stylegan_tf/training/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cf142226b1794b675d61151467444cb65bdaa1a0 --- /dev/null +++ b/models/stylegan/stylegan_tf/training/dataset.py @@ -0,0 +1,241 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
+# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Multi-resolution input data pipeline.""" + +import os +import glob +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib + +#---------------------------------------------------------------------------- +# Parse individual image from a tfrecords file. + +def parse_tfrecord_tf(record): + features = tf.parse_single_example(record, features={ + 'shape': tf.FixedLenFeature([3], tf.int64), + 'data': tf.FixedLenFeature([], tf.string)}) + data = tf.decode_raw(features['data'], tf.uint8) + return tf.reshape(data, features['shape']) + +def parse_tfrecord_np(record): + ex = tf.train.Example() + ex.ParseFromString(record) + shape = ex.features.feature['shape'].int64_list.value # temporary pylint workaround # pylint: disable=no-member + data = ex.features.feature['data'].bytes_list.value[0] # temporary pylint workaround # pylint: disable=no-member + return np.fromstring(data, np.uint8).reshape(shape) + +#---------------------------------------------------------------------------- +# Dataset class that loads data from tfrecords files. + +class TFRecordDataset: + def __init__(self, + tfrecord_dir, # Directory containing a collection of tfrecords files. + resolution = None, # Dataset resolution, None = autodetect. + label_file = None, # Relative path of the labels file, None = autodetect. + max_label_size = 0, # 0 = no labels, 'full' = full labels, = N first label components. + repeat = True, # Repeat dataset indefinitely. + shuffle_mb = 4096, # Shuffle data within specified window (megabytes), 0 = disable shuffling. + prefetch_mb = 2048, # Amount of data to prefetch (megabytes), 0 = disable prefetching. + buffer_mb = 256, # Read buffer size (megabytes). + num_threads = 2): # Number of concurrent threads. + + self.tfrecord_dir = tfrecord_dir + self.resolution = None + self.resolution_log2 = None + self.shape = [] # [channel, height, width] + self.dtype = 'uint8' + self.dynamic_range = [0, 255] + self.label_file = label_file + self.label_size = None # [component] + self.label_dtype = None + self._np_labels = None + self._tf_minibatch_in = None + self._tf_labels_var = None + self._tf_labels_dataset = None + self._tf_datasets = dict() + self._tf_iterator = None + self._tf_init_ops = dict() + self._tf_minibatch_np = None + self._cur_minibatch = -1 + self._cur_lod = -1 + + # List tfrecords files and inspect their shapes. + assert os.path.isdir(self.tfrecord_dir) + tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords'))) + assert len(tfr_files) >= 1 + tfr_shapes = [] + for tfr_file in tfr_files: + tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE) + for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt): + tfr_shapes.append(parse_tfrecord_np(record).shape) + break + + # Autodetect label filename. + if self.label_file is None: + guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels'))) + if len(guess): + self.label_file = guess[0] + elif not os.path.isfile(self.label_file): + guess = os.path.join(self.tfrecord_dir, self.label_file) + if os.path.isfile(guess): + self.label_file = guess + + # Determine shape and resolution. 
+ max_shape = max(tfr_shapes, key=np.prod) + self.resolution = resolution if resolution is not None else max_shape[1] + self.resolution_log2 = int(np.log2(self.resolution)) + self.shape = [max_shape[0], self.resolution, self.resolution] + tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes] + assert all(shape[0] == max_shape[0] for shape in tfr_shapes) + assert all(shape[1] == shape[2] for shape in tfr_shapes) + assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods)) + assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1)) + + # Load labels. + assert max_label_size == 'full' or max_label_size >= 0 + self._np_labels = np.zeros([1<<20, 0], dtype=np.float32) + if self.label_file is not None and max_label_size != 0: + self._np_labels = np.load(self.label_file) + assert self._np_labels.ndim == 2 + if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size: + self._np_labels = self._np_labels[:, :max_label_size] + self.label_size = self._np_labels.shape[1] + self.label_dtype = self._np_labels.dtype.name + + # Build TF expressions. + with tf.name_scope('Dataset'), tf.device('/cpu:0'): + self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[]) + self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var') + self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var) + for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods): + if tfr_lod < 0: + continue + dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20) + dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads) + dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset)) + bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize + if shuffle_mb > 0: + dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1) + if repeat: + dset = dset.repeat() + if prefetch_mb > 0: + dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1) + dset = dset.batch(self._tf_minibatch_in) + self._tf_datasets[tfr_lod] = dset + self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes) + self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()} + + # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf(). + def configure(self, minibatch_size, lod=0): + lod = int(np.floor(lod)) + assert minibatch_size >= 1 and lod in self._tf_datasets + if self._cur_minibatch != minibatch_size or self._cur_lod != lod: + self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size}) + self._cur_minibatch = minibatch_size + self._cur_lod = lod + + # Get next minibatch as TensorFlow expressions. + def get_minibatch_tf(self): # => images, labels + return self._tf_iterator.get_next() + + # Get next minibatch as NumPy arrays. + def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels + self.configure(minibatch_size, lod) + if self._tf_minibatch_np is None: + self._tf_minibatch_np = self.get_minibatch_tf() + return tflib.run(self._tf_minibatch_np) + + # Get random labels as TensorFlow expression. 
+ def get_random_labels_tf(self, minibatch_size): # => labels + if self.label_size > 0: + with tf.device('/cpu:0'): + return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32)) + return tf.zeros([minibatch_size, 0], self.label_dtype) + + # Get random labels as NumPy array. + def get_random_labels_np(self, minibatch_size): # => labels + if self.label_size > 0: + return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])] + return np.zeros([minibatch_size, 0], self.label_dtype) + +#---------------------------------------------------------------------------- +# Base class for datasets that are generated on the fly. + +class SyntheticDataset: + def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0,255], label_size=0, label_dtype='float32'): + self.resolution = resolution + self.resolution_log2 = int(np.log2(resolution)) + self.shape = [num_channels, resolution, resolution] + self.dtype = dtype + self.dynamic_range = dynamic_range + self.label_size = label_size + self.label_dtype = label_dtype + self._tf_minibatch_var = None + self._tf_lod_var = None + self._tf_minibatch_np = None + self._tf_labels_np = None + + assert self.resolution == 2 ** self.resolution_log2 + with tf.name_scope('Dataset'): + self._tf_minibatch_var = tf.Variable(np.int32(0), name='minibatch_var') + self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var') + + def configure(self, minibatch_size, lod=0): + lod = int(np.floor(lod)) + assert minibatch_size >= 1 and 0 <= lod <= self.resolution_log2 + tflib.set_vars({self._tf_minibatch_var: minibatch_size, self._tf_lod_var: lod}) + + def get_minibatch_tf(self): # => images, labels + with tf.name_scope('SyntheticDataset'): + shrink = tf.cast(2.0 ** tf.cast(self._tf_lod_var, tf.float32), tf.int32) + shape = [self.shape[0], self.shape[1] // shrink, self.shape[2] // shrink] + images = self._generate_images(self._tf_minibatch_var, self._tf_lod_var, shape) + labels = self._generate_labels(self._tf_minibatch_var) + return images, labels + + def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels + self.configure(minibatch_size, lod) + if self._tf_minibatch_np is None: + self._tf_minibatch_np = self.get_minibatch_tf() + return tflib.run(self._tf_minibatch_np) + + def get_random_labels_tf(self, minibatch_size): # => labels + with tf.name_scope('SyntheticDataset'): + return self._generate_labels(minibatch_size) + + def get_random_labels_np(self, minibatch_size): # => labels + self.configure(minibatch_size) + if self._tf_labels_np is None: + self._tf_labels_np = self.get_random_labels_tf(minibatch_size) + return tflib.run(self._tf_labels_np) + + def _generate_images(self, minibatch, lod, shape): # to be overridden by subclasses # pylint: disable=unused-argument + return tf.zeros([minibatch] + shape, self.dtype) + + def _generate_labels(self, minibatch): # to be overridden by subclasses + return tf.zeros([minibatch, self.label_size], self.label_dtype) + +#---------------------------------------------------------------------------- +# Helper func for constructing a dataset object using the given options. + +def load_dataset(class_name='training.dataset.TFRecordDataset', data_dir=None, verbose=False, **kwargs): + adjusted_kwargs = dict(kwargs) + if 'tfrecord_dir' in adjusted_kwargs and data_dir is not None: + adjusted_kwargs['tfrecord_dir'] = os.path.join(data_dir, adjusted_kwargs['tfrecord_dir']) + if verbose: + print('Streaming data using %s...' 
% class_name) + dataset = dnnlib.util.get_obj_by_name(class_name)(**adjusted_kwargs) + if verbose: + print('Dataset shape =', np.int32(dataset.shape).tolist()) + print('Dynamic range =', dataset.dynamic_range) + print('Label size =', dataset.label_size) + return dataset + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/training/loss.py b/models/stylegan/stylegan_tf/training/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..aa59b61bf316f73f269849b54ec3bb35b6a0d61d --- /dev/null +++ b/models/stylegan/stylegan_tf/training/loss.py @@ -0,0 +1,177 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Loss functions.""" + +import tensorflow as tf +import dnnlib.tflib as tflib +from dnnlib.tflib.autosummary import autosummary + +#---------------------------------------------------------------------------- +# Convenience func that casts all of its arguments to tf.float32. + +def fp32(*values): + if len(values) == 1 and isinstance(values[0], tuple): + values = values[0] + values = tuple(tf.cast(v, tf.float32) for v in values) + return values if len(values) >= 2 else values[0] + +#---------------------------------------------------------------------------- +# WGAN & WGAN-GP loss functions. + +def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + labels = training_set.get_random_labels_tf(minibatch_size) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + loss = -fake_scores_out + return loss + +def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument + wgan_epsilon = 0.001): # Weight for the epsilon term, \epsilon_{drift}. + + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = fake_scores_out - real_scores_out + + with tf.name_scope('EpsilonPenalty'): + epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) + loss += epsilon_penalty * wgan_epsilon + return loss + +def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument + wgan_lambda = 10.0, # Weight for the gradient penalty term. + wgan_epsilon = 0.001, # Weight for the epsilon term, \epsilon_{drift}. + wgan_target = 1.0): # Target value for gradient magnitudes. 
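    # Note on the objective assembled below: this is the WGAN-GP discriminator loss,
    #
    #   L_D = E[D(fake)] - E[D(real)]
    #         + (wgan_lambda / wgan_target^2) * E[(||grad_xhat D(xhat)||_2 - wgan_target)^2]
    #         + wgan_epsilon * E[D(real)^2]
    #
    # where xhat = lerp(real, fake, t) with a per-sample t ~ U[0, 1]. The last term is the
    # small drift penalty that keeps the real scores from wandering far from zero.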
+ + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = fake_scores_out - real_scores_out + + with tf.name_scope('GradientPenalty'): + mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) + mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) + mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) + mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) + mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) + mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) + mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) + mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) + gradient_penalty = tf.square(mixed_norms - wgan_target) + loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) + + with tf.name_scope('EpsilonPenalty'): + epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out)) + loss += epsilon_penalty * wgan_epsilon + return loss + +#---------------------------------------------------------------------------- +# Hinge loss functions. (Use G_wgan with these) + +def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) + return loss + +def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument + wgan_lambda = 10.0, # Weight for the gradient penalty term. + wgan_target = 1.0): # Target value for gradient magnitudes. 
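    # Note on the objective assembled below: this is the hinge discriminator loss,
    #
    #   L_D = E[max(0, 1 + D(fake))] + E[max(0, 1 - D(real))]
    #
    # combined with the same interpolate-based gradient penalty as in D_wgan_gp() above.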
+ + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out) + + with tf.name_scope('GradientPenalty'): + mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype) + mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors) + mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True)) + mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out) + mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out)) + mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0])) + mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3])) + mixed_norms = autosummary('Loss/mixed_norms', mixed_norms) + gradient_penalty = tf.square(mixed_norms - wgan_target) + loss += gradient_penalty * (wgan_lambda / (wgan_target**2)) + return loss + + +#---------------------------------------------------------------------------- +# Loss functions advocated by the paper +# "Which Training Methods for GANs do actually Converge?" + +def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + labels = training_set.get_random_labels_tf(minibatch_size) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + loss = -tf.nn.softplus(fake_scores_out) # log(1 - logistic(fake_scores_out)) + return loss + +def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + labels = training_set.get_random_labels_tf(minibatch_size) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + loss = tf.nn.softplus(-fake_scores_out) # -log(logistic(fake_scores_out)) + return loss + +def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) + loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type + return loss + +def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument + latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:]) + 
fake_images_out = G.get_output_for(latents, labels, is_training=True) + real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True)) + fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True)) + real_scores_out = autosummary('Loss/scores/real', real_scores_out) + fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out) + loss = tf.nn.softplus(fake_scores_out) # -log(1 - logistic(fake_scores_out)) + loss += tf.nn.softplus(-real_scores_out) # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type + + if r1_gamma != 0.0: + with tf.name_scope('R1Penalty'): + real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out)) + real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0])) + r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3]) + r1_penalty = autosummary('Loss/r1_penalty', r1_penalty) + loss += r1_penalty * (r1_gamma * 0.5) + + if r2_gamma != 0.0: + with tf.name_scope('R2Penalty'): + fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out)) + fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0])) + r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3]) + r2_penalty = autosummary('Loss/r2_penalty', r2_penalty) + loss += r2_penalty * (r2_gamma * 0.5) + return loss + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/training/misc.py b/models/stylegan/stylegan_tf/training/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..50ae51c722cb1e553c56051cbd4556110fe4a1f9 --- /dev/null +++ b/models/stylegan/stylegan_tf/training/misc.py @@ -0,0 +1,245 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Miscellaneous utility functions.""" + +import os +import glob +import pickle +import re +import numpy as np +from collections import defaultdict +import PIL.Image +import dnnlib + +import config +from training import dataset + +#---------------------------------------------------------------------------- +# Convenience wrappers for pickle that are able to load data produced by +# older versions of the code, and from external URLs. + +def open_file_or_url(file_or_url): + if dnnlib.util.is_url(file_or_url): + return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir) + return open(file_or_url, 'rb') + +def load_pkl(file_or_url): + with open_file_or_url(file_or_url) as file: + return pickle.load(file, encoding='latin1') + +def save_pkl(obj, filename): + with open(filename, 'wb') as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) + +#---------------------------------------------------------------------------- +# Image utils. 
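# adjust_dynamic_range() below is a plain linear remap. As a worked example, mapping
# uint8 pixels in [0, 255] to a [-1, 1] network range uses
#   scale = (1 - (-1)) / (255 - 0) = 2/255,   bias = -1 - 0 * scale = -1,
# so 0 -> -1.0, 128 -> ~0.004, 255 -> 1.0. A typical (illustrative) round trip back to an
# image file would be:
#   convert_to_pil_image(fake_image, drange=[-1, 1]).save('fake.png')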
+ +def adjust_dynamic_range(data, drange_in, drange_out): + if drange_in != drange_out: + scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0])) + bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale) + data = data * scale + bias + return data + +def create_image_grid(images, grid_size=None): + assert images.ndim == 3 or images.ndim == 4 + num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2] + + if grid_size is not None: + grid_w, grid_h = tuple(grid_size) + else: + grid_w = max(int(np.ceil(np.sqrt(num))), 1) + grid_h = max((num - 1) // grid_w + 1, 1) + + grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype) + for idx in range(num): + x = (idx % grid_w) * img_w + y = (idx // grid_w) * img_h + grid[..., y : y + img_h, x : x + img_w] = images[idx] + return grid + +def convert_to_pil_image(image, drange=[0,1]): + assert image.ndim == 2 or image.ndim == 3 + if image.ndim == 3: + if image.shape[0] == 1: + image = image[0] # grayscale CHW => HW + else: + image = image.transpose(1, 2, 0) # CHW -> HWC + + image = adjust_dynamic_range(image, drange, [0,255]) + image = np.rint(image).clip(0, 255).astype(np.uint8) + fmt = 'RGB' if image.ndim == 3 else 'L' + return PIL.Image.fromarray(image, fmt) + +def save_image(image, filename, drange=[0,1], quality=95): + img = convert_to_pil_image(image, drange) + if '.jpg' in filename: + img.save(filename,"JPEG", quality=quality, optimize=True) + else: + img.save(filename) + +def save_image_grid(images, filename, drange=[0,1], grid_size=None): + convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename) + +#---------------------------------------------------------------------------- +# Locating results. 
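# Illustrative usage of the helpers below (the run-dir name and snapshot layout are
# examples only; actual names come from the training code):
#   run_dir = locate_run_dir(5)           # e.g. results/00005-<desc>
#   pkls = list_network_pkls(run_dir)     # snapshot pickles, final snapshot last
#   G, D, Gs = load_network_pkl(run_dir)  # latest snapshot; assumes the usual (G, D, Gs) tuple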
+ +def locate_run_dir(run_id_or_run_dir): + if isinstance(run_id_or_run_dir, str): + if os.path.isdir(run_id_or_run_dir): + return run_id_or_run_dir + converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir) + if os.path.isdir(converted): + return converted + + run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir)) + for search_dir in ['']: + full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir)) + run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir)) + if os.path.isdir(run_dir): + return run_dir + run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*'))) + run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))] + run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)] + if len(run_dirs) == 1: + return run_dirs[0] + raise IOError('Cannot locate result subdir for run', run_id_or_run_dir) + +def list_network_pkls(run_id_or_run_dir, include_final=True): + run_dir = locate_run_dir(run_id_or_run_dir) + pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl'))) + if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl': + if include_final: + pkls.append(pkls[0]) + del pkls[0] + return pkls + +def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): + for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]: + if isinstance(candidate, str): + if os.path.isfile(candidate): + return candidate + converted = dnnlib.submission.submit.convert_path(candidate) + if os.path.isfile(converted): + return converted + + pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl) + if len(pkls) >= 1 and snapshot_or_network_pkl is None: + return pkls[-1] + + for pkl in pkls: + try: + name = os.path.splitext(os.path.basename(pkl))[0] + number = int(name.split('-')[-1]) + if number == snapshot_or_network_pkl: + return pkl + except ValueError: pass + except IndexError: pass + raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl) + +def get_id_string_for_network_pkl(network_pkl): + p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/') + return '-'.join(p[max(len(p) - 2, 0):]) + +#---------------------------------------------------------------------------- +# Loading data from previous training runs. + +def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None): + return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl)) + +def parse_config_for_previous_run(run_id): + run_dir = locate_run_dir(run_id) + + # Parse config.txt. + cfg = defaultdict(dict) + with open(os.path.join(run_dir, 'config.txt'), 'rt') as f: + for line in f: + line = re.sub(r"^{?\s*'(\w+)':\s*{(.*)(},|}})$", r"\1 = {\2}", line.strip()) + if line.startswith('dataset =') or line.startswith('train ='): + exec(line, cfg, cfg) # pylint: disable=exec-used + + # Handle legacy options. 
+ if 'file_pattern' in cfg['dataset']: + cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '') + if 'mirror_augment' in cfg['dataset']: + cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment') + if 'max_labels' in cfg['dataset']: + v = cfg['dataset'].pop('max_labels') + if v is None: v = 0 + if v == 'all': v = 'full' + cfg['dataset']['max_label_size'] = v + if 'max_images' in cfg['dataset']: + cfg['dataset'].pop('max_images') + return cfg + +def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment + cfg = parse_config_for_previous_run(run_id) + cfg['dataset'].update(kwargs) + dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset']) + mirror_augment = cfg['train'].get('mirror_augment', False) + return dataset_obj, mirror_augment + +def apply_mirror_augment(minibatch): + mask = np.random.rand(minibatch.shape[0]) < 0.5 + minibatch = np.array(minibatch) + minibatch[mask] = minibatch[mask, :, :, ::-1] + return minibatch + +#---------------------------------------------------------------------------- +# Size and contents of the image snapshot grids that are exported +# periodically during training. + +def setup_snapshot_image_grid(G, training_set, + size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display. + layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label. + + # Select size. + gw = 1; gh = 1 + if size == '1080p': + gw = np.clip(1920 // G.output_shape[3], 3, 32) + gh = np.clip(1080 // G.output_shape[2], 2, 32) + if size == '4k': + gw = np.clip(3840 // G.output_shape[3], 7, 32) + gh = np.clip(2160 // G.output_shape[2], 4, 32) + + # Initialize data arrays. + reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype) + labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype) + latents = np.random.randn(gw * gh, *G.input_shape[1:]) + + # Random layout. + if layout == 'random': + reals[:], labels[:] = training_set.get_minibatch_np(gw * gh) + + # Class-conditional layouts. + class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4]) + if layout in class_layouts: + bw, bh = class_layouts[layout] + nw = (gw - 1) // bw + 1 + nh = (gh - 1) // bh + 1 + blocks = [[] for _i in range(nw * nh)] + for _iter in range(1000000): + real, label = training_set.get_minibatch_np(1) + idx = np.argmax(label[0]) + while idx < len(blocks) and len(blocks[idx]) >= bw * bh: + idx += training_set.label_size + if idx < len(blocks): + blocks[idx].append((real, label)) + if all(len(block) >= bw * bh for block in blocks): + break + for i, block in enumerate(blocks): + for j, (real, label) in enumerate(block): + x = (i % nw) * bw + j % bw + y = (i // nw) * bh + j // bw + if x < gw and y < gh: + reals[x + y * gw] = real[0] + labels[x + y * gw] = label[0] + + return (gw, gh), reals, labels, latents + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/training/networks_progan.py b/models/stylegan/stylegan_tf/training/networks_progan.py new file mode 100644 index 0000000000000000000000000000000000000000..896f500b0bfca5c292b1cba8de79e270f6a08036 --- /dev/null +++ b/models/stylegan/stylegan_tf/training/networks_progan.py @@ -0,0 +1,322 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
+# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Network architectures used in the ProGAN paper.""" + +import numpy as np +import tensorflow as tf + +# NOTE: Do not import any application-specific modules here! +# Specify all network parameters as kwargs. + +#---------------------------------------------------------------------------- + +def lerp(a, b, t): return a + (b - a) * t +def lerp_clip(a, b, t): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) +def cset(cur_lambda, new_cond, new_lambda): return lambda: tf.cond(new_cond, new_lambda, cur_lambda) + +#---------------------------------------------------------------------------- +# Get/create weight tensor for a convolutional or fully-connected layer. + +def get_weight(shape, gain=np.sqrt(2), use_wscale=False): + fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] + std = gain / np.sqrt(fan_in) # He init + if use_wscale: + wscale = tf.constant(np.float32(std), name='wscale') + w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale + else: + w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std)) + return w + +#---------------------------------------------------------------------------- +# Fully-connected layer. + +def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False): + if len(x.shape) > 2: + x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) + w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) + w = tf.cast(w, x.dtype) + return tf.matmul(x, w) + +#---------------------------------------------------------------------------- +# Convolutional layer. + +def conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): + assert kernel >= 1 and kernel % 2 == 1 + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) + w = tf.cast(w, x.dtype) + return tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Apply bias to the given activation tensor. + +def apply_bias(x): + b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) + b = tf.cast(b, x.dtype) + if len(x.shape) == 2: + return x + b + return x + tf.reshape(b, [1, -1, 1, 1]) + +#---------------------------------------------------------------------------- +# Leaky ReLU activation. Same as tf.nn.leaky_relu, but supports FP16. + +def leaky_relu(x, alpha=0.2): + with tf.name_scope('LeakyRelu'): + alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') + return tf.maximum(x * alpha, x) + +#---------------------------------------------------------------------------- +# Nearest-neighbor upscaling layer. + +def upscale2d(x, factor=2): + assert isinstance(factor, int) and factor >= 1 + if factor == 1: return x + with tf.variable_scope('Upscale2D'): + s = x.shape + x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) + x = tf.tile(x, [1, 1, 1, factor, 1, factor]) + x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) + return x + +#---------------------------------------------------------------------------- +# Fused upscale2d + conv2d. +# Faster and uses less memory than performing the operations separately. 
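# The fusion trick below: the 3x3 kernel is zero-padded to 5x5 and its four one-pixel
# shifted copies are summed, yielding a 4x4 kernel. Running that kernel through
# tf.nn.conv2d_transpose() with stride 2 produces (up to border handling) the same result
# as nearest-neighbor upscale2d() followed by conv2d(), in a single op.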
+ +def upscale2d_conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): + assert kernel >= 1 and kernel % 2 == 1 + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) + w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in] + w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') + w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) + w = tf.cast(w, x.dtype) + os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2] + return tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Box filter downscaling layer. + +def downscale2d(x, factor=2): + assert isinstance(factor, int) and factor >= 1 + if factor == 1: return x + with tf.variable_scope('Downscale2D'): + ksize = [1, 1, factor, factor] + return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') # NOTE: requires tf_config['graph_options.place_pruned_graph'] = True + +#---------------------------------------------------------------------------- +# Fused conv2d + downscale2d. +# Faster and uses less memory than performing the operations separately. + +def conv2d_downscale2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False): + assert kernel >= 1 and kernel % 2 == 1 + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale) + w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') + w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25 + w = tf.cast(w, x.dtype) + return tf.nn.conv2d(x, w, strides=[1,1,2,2], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Pixelwise feature vector normalization. + +def pixel_norm(x, epsilon=1e-8): + with tf.variable_scope('PixelNorm'): + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) + +#---------------------------------------------------------------------------- +# Minibatch standard deviation. + +def minibatch_stddev_layer(x, group_size=4, num_new_features=1): + with tf.variable_scope('MinibatchStddev'): + group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. + s = x.shape # [NCHW] Input shape. + y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. + y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32. + y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group. + y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group. + y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. + y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels. + y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups + y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type. + y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels. + return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap. + +#---------------------------------------------------------------------------- +# Networks used in the ProgressiveGAN paper. + +def G_paper( + latents_in, # First input: Latent vectors [minibatch, latent_size]. + labels_in, # Second input: Labels [minibatch, label_size]. 
+ num_channels = 1, # Number of output color channels. Overridden based on dataset. + resolution = 32, # Output resolution. Overridden based on dataset. + label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. + fmap_base = 8192, # Overall multiplier for the number of feature maps. + fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. + fmap_max = 512, # Maximum number of feature maps in any layer. + latent_size = None, # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max). + normalize_latents = True, # Normalize latent vectors before feeding them to the network? + use_wscale = True, # Enable equalized learning rate? + use_pixelnorm = True, # Enable pixelwise feature vector normalization? + pixelnorm_epsilon = 1e-8, # Constant epsilon for pixelwise feature vector normalization. + use_leakyrelu = True, # True = leaky ReLU, False = ReLU. + dtype = 'float32', # Data type to use for activations and outputs. + fused_scale = True, # True = use fused upscale2d + conv2d, False = separate upscale2d layers. + structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + **_kwargs): # Ignore unrecognized keyword args. + + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x + if latent_size is None: latent_size = nf(0) + if structure is None: structure = 'linear' if is_template_graph else 'recursive' + act = leaky_relu if use_leakyrelu else tf.nn.relu + + latents_in.set_shape([None, latent_size]) + labels_in.set_shape([None, label_size]) + combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype) + lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) + images_out = None + + # Building blocks. + def block(x, res): # res = 2..resolution_log2 + with tf.variable_scope('%dx%d' % (2**res, 2**res)): + if res == 2: # 4x4 + if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon) + with tf.variable_scope('Dense'): + x = dense(x, fmaps=nf(res-1)*16, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation + x = tf.reshape(x, [-1, nf(res-1), 4, 4]) + x = PN(act(apply_bias(x))) + with tf.variable_scope('Conv'): + x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) + else: # 8x8 and up + if fused_scale: + with tf.variable_scope('Conv0_up'): + x = PN(act(apply_bias(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) + else: + x = upscale2d(x) + with tf.variable_scope('Conv0'): + x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) + with tf.variable_scope('Conv1'): + x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))) + return x + def torgb(x, res): # res = 2..resolution_log2 + lod = resolution_log2 - res + with tf.variable_scope('ToRGB_lod%d' % lod): + return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale)) + + # Linear structure: simple but inefficient. 
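    # (In the linear structure, each new resolution is faded in smoothly: at every stage,
    # lerp_clip(img, upscaled_previous, lod_in - lod) blends the fresh torgb() output with
    # the upscaled image from the previous stage, so an integer lod_in selects a resolution
    # and fractional values interpolate during the growth transition.)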
+ if structure == 'linear': + x = block(combo_in, 2) + images_out = torgb(x, 2) + for res in range(3, resolution_log2 + 1): + lod = resolution_log2 - res + x = block(x, res) + img = torgb(x, res) + images_out = upscale2d(images_out) + with tf.variable_scope('Grow_lod%d' % lod): + images_out = lerp_clip(img, images_out, lod_in - lod) + + # Recursive structure: complex but efficient. + if structure == 'recursive': + def grow(x, res, lod): + y = block(x, res) + img = lambda: upscale2d(torgb(y, res), 2**lod) + if res > 2: img = cset(img, (lod_in > lod), lambda: upscale2d(lerp(torgb(y, res), upscale2d(torgb(x, res - 1)), lod_in - lod), 2**lod)) + if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1)) + return img() + images_out = grow(combo_in, 2, resolution_log2 - 2) + + assert images_out.dtype == tf.as_dtype(dtype) + images_out = tf.identity(images_out, name='images_out') + return images_out + + +def D_paper( + images_in, # First input: Images [minibatch, channel, height, width]. + labels_in, # Second input: Labels [minibatch, label_size]. + num_channels = 1, # Number of input color channels. Overridden based on dataset. + resolution = 32, # Input resolution. Overridden based on dataset. + label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. + fmap_base = 8192, # Overall multiplier for the number of feature maps. + fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. + fmap_max = 512, # Maximum number of feature maps in any layer. + use_wscale = True, # Enable equalized learning rate? + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. + dtype = 'float32', # Data type to use for activations and outputs. + fused_scale = True, # True = use fused conv2d + downscale2d, False = separate downscale2d layers. + structure = None, # 'linear' = human-readable, 'recursive' = efficient, None = select automatically + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + **_kwargs): # Ignore unrecognized keyword args. + + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + if structure is None: structure = 'linear' if is_template_graph else 'recursive' + act = leaky_relu + + images_in.set_shape([None, num_channels, resolution, resolution]) + labels_in.set_shape([None, label_size]) + images_in = tf.cast(images_in, dtype) + labels_in = tf.cast(labels_in, dtype) + lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) + scores_out = None + + # Building blocks. 
+ def fromrgb(x, res): # res = 2..resolution_log2 + with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)): + return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale))) + def block(x, res): # res = 2..resolution_log2 + with tf.variable_scope('%dx%d' % (2**res, 2**res)): + if res >= 3: # 8x8 and up + with tf.variable_scope('Conv0'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) + if fused_scale: + with tf.variable_scope('Conv1_down'): + x = act(apply_bias(conv2d_downscale2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) + else: + with tf.variable_scope('Conv1'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale))) + x = downscale2d(x) + else: # 4x4 + if mbstd_group_size > 1: + x = minibatch_stddev_layer(x, mbstd_group_size) + with tf.variable_scope('Conv'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))) + with tf.variable_scope('Dense0'): + x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale))) + with tf.variable_scope('Dense1'): + x = apply_bias(dense(x, fmaps=1, gain=1, use_wscale=use_wscale)) + return x + + # Linear structure: simple but inefficient. + if structure == 'linear': + img = images_in + x = fromrgb(img, resolution_log2) + for res in range(resolution_log2, 2, -1): + lod = resolution_log2 - res + x = block(x, res) + img = downscale2d(img) + y = fromrgb(img, res - 1) + with tf.variable_scope('Grow_lod%d' % lod): + x = lerp_clip(x, y, lod_in - lod) + scores_out = block(x, 2) + + # Recursive structure: complex but efficient. + if structure == 'recursive': + def grow(res, lod): + x = lambda: fromrgb(downscale2d(images_in, 2**lod), res) + if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1)) + x = block(x(), res); y = lambda: x + if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod)) + return y() + scores_out = grow(2, resolution_log2 - 2) + + assert scores_out.dtype == tf.as_dtype(dtype) + scores_out = tf.identity(scores_out, name='scores_out') + return scores_out + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/training/networks_stylegan.py b/models/stylegan/stylegan_tf/training/networks_stylegan.py new file mode 100644 index 0000000000000000000000000000000000000000..adc4b260f6f94570c793b0086280f757d2e19ad1 --- /dev/null +++ b/models/stylegan/stylegan_tf/training/networks_stylegan.py @@ -0,0 +1,661 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Network architectures used in the StyleGAN paper.""" + +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib + +# NOTE: Do not import any application-specific modules here! +# Specify all network parameters as kwargs. + +#---------------------------------------------------------------------------- +# Primitive ops for manipulating 4D activation tensors. +# The gradients of these are not necessary efficient or even meaningful. 
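# _blur2d() below expands the 1-D filter f = [1,2,1] into its normalized 2-D outer
# product,
#   [[1,2,1],[2,4,2],[1,2,1]] / 16,
# and applies it per channel with tf.nn.depthwise_conv2d(). The blur2d()/upscale2d()/
# downscale2d() wrappers further down attach custom gradients so that backprop reuses the
# matching flipped / rescaled primitive instead of TensorFlow's default gradient.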
+ +def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1): + assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:]) + assert isinstance(stride, int) and stride >= 1 + + # Finalize filter kernel. + f = np.array(f, dtype=np.float32) + if f.ndim == 1: + f = f[:, np.newaxis] * f[np.newaxis, :] + assert f.ndim == 2 + if normalize: + f /= np.sum(f) + if flip: + f = f[::-1, ::-1] + f = f[:, :, np.newaxis, np.newaxis] + f = np.tile(f, [1, 1, int(x.shape[1]), 1]) + + # No-op => early exit. + if f.shape == (1, 1) and f[0,0] == 1: + return x + + # Convolve using depthwise_conv2d. + orig_dtype = x.dtype + x = tf.cast(x, tf.float32) # tf.nn.depthwise_conv2d() doesn't support fp16 + f = tf.constant(f, dtype=x.dtype, name='filter') + strides = [1, 1, stride, stride] + x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW') + x = tf.cast(x, orig_dtype) + return x + +def _upscale2d(x, factor=2, gain=1): + assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:]) + assert isinstance(factor, int) and factor >= 1 + + # Apply gain. + if gain != 1: + x *= gain + + # No-op => early exit. + if factor == 1: + return x + + # Upscale using tf.tile(). + s = x.shape + x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) + x = tf.tile(x, [1, 1, 1, factor, 1, factor]) + x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) + return x + +def _downscale2d(x, factor=2, gain=1): + assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:]) + assert isinstance(factor, int) and factor >= 1 + + # 2x2, float32 => downscale using _blur2d(). + if factor == 2 and x.dtype == tf.float32: + f = [np.sqrt(gain) / factor] * factor + return _blur2d(x, f=f, normalize=False, stride=factor) + + # Apply gain. + if gain != 1: + x *= gain + + # No-op => early exit. + if factor == 1: + return x + + # Large factor => downscale using tf.nn.avg_pool(). + # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work. + ksize = [1, 1, factor, factor] + return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') + +#---------------------------------------------------------------------------- +# High-level ops for manipulating 4D activation tensors. +# The gradients of these are meant to be as efficient as possible. + +def blur2d(x, f=[1,2,1], normalize=True): + with tf.variable_scope('Blur2D'): + @tf.custom_gradient + def func(x): + y = _blur2d(x, f, normalize) + @tf.custom_gradient + def grad(dy): + dx = _blur2d(dy, f, normalize, flip=True) + return dx, lambda ddx: _blur2d(ddx, f, normalize) + return y, grad + return func(x) + +def upscale2d(x, factor=2): + with tf.variable_scope('Upscale2D'): + @tf.custom_gradient + def func(x): + y = _upscale2d(x, factor) + @tf.custom_gradient + def grad(dy): + dx = _downscale2d(dy, factor, gain=factor**2) + return dx, lambda ddx: _upscale2d(ddx, factor) + return y, grad + return func(x) + +def downscale2d(x, factor=2): + with tf.variable_scope('Downscale2D'): + @tf.custom_gradient + def func(x): + y = _downscale2d(x, factor) + @tf.custom_gradient + def grad(dy): + dx = _upscale2d(dy, factor, gain=1/factor**2) + return dx, lambda ddx: _downscale2d(ddx, factor) + return y, grad + return func(x) + +#---------------------------------------------------------------------------- +# Get/create weight tensor for a convolutional or fully-connected layer. 
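# Equalized learning rate as implemented below: weights are stored with std 1/lrmul and
# multiplied at runtime by he_std * lrmul, where he_std = gain / sqrt(fan_in). For
# example, a 3x3 conv with 512 input feature maps and gain = sqrt(2) gives
# he_std = sqrt(2 / (3*3*512)) ~= 0.0208, so the optimizer always sees roughly
# unit-variance parameters while the effective weights keep their He-initialized scale.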
+ +def get_weight(shape, gain=np.sqrt(2), use_wscale=False, lrmul=1): + fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] + he_std = gain / np.sqrt(fan_in) # He init + + # Equalized learning rate and custom learning rate multiplier. + if use_wscale: + init_std = 1.0 / lrmul + runtime_coef = he_std * lrmul + else: + init_std = he_std / lrmul + runtime_coef = lrmul + + # Create variable. + init = tf.initializers.random_normal(0, init_std) + return tf.get_variable('weight', shape=shape, initializer=init) * runtime_coef + +#---------------------------------------------------------------------------- +# Fully-connected layer. + +def dense(x, fmaps, **kwargs): + if len(x.shape) > 2: + x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) + w = get_weight([x.shape[1].value, fmaps], **kwargs) + w = tf.cast(w, x.dtype) + return tf.matmul(x, w) + +#---------------------------------------------------------------------------- +# Convolutional layer. + +def conv2d(x, fmaps, kernel, **kwargs): + assert kernel >= 1 and kernel % 2 == 1 + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs) + w = tf.cast(w, x.dtype) + return tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Fused convolution + scaling. +# Faster and uses less memory than performing the operations separately. + +def upscale2d_conv2d(x, fmaps, kernel, fused_scale='auto', **kwargs): + assert kernel >= 1 and kernel % 2 == 1 + assert fused_scale in [True, False, 'auto'] + if fused_scale == 'auto': + fused_scale = min(x.shape[2:]) * 2 >= 128 + + # Not fused => call the individual ops directly. + if not fused_scale: + return conv2d(upscale2d(x), fmaps, kernel, **kwargs) + + # Fused => perform both ops simultaneously using tf.nn.conv2d_transpose(). + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs) + w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in] + w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') + w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) + w = tf.cast(w, x.dtype) + os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2] + return tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW') + +def conv2d_downscale2d(x, fmaps, kernel, fused_scale='auto', **kwargs): + assert kernel >= 1 and kernel % 2 == 1 + assert fused_scale in [True, False, 'auto'] + if fused_scale == 'auto': + fused_scale = min(x.shape[2:]) >= 128 + + # Not fused => call the individual ops directly. + if not fused_scale: + return downscale2d(conv2d(x, fmaps, kernel, **kwargs)) + + # Fused => perform both ops simultaneously using tf.nn.conv2d(). + w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs) + w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT') + w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25 + w = tf.cast(w, x.dtype) + return tf.nn.conv2d(x, w, strides=[1,1,2,2], padding='SAME', data_format='NCHW') + +#---------------------------------------------------------------------------- +# Apply bias to the given activation tensor. 
+ +def apply_bias(x, lrmul=1): + b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) * lrmul + b = tf.cast(b, x.dtype) + if len(x.shape) == 2: + return x + b + return x + tf.reshape(b, [1, -1, 1, 1]) + +#---------------------------------------------------------------------------- +# Leaky ReLU activation. More efficient than tf.nn.leaky_relu() and supports FP16. + +def leaky_relu(x, alpha=0.2): + with tf.variable_scope('LeakyReLU'): + alpha = tf.constant(alpha, dtype=x.dtype, name='alpha') + @tf.custom_gradient + def func(x): + y = tf.maximum(x, x * alpha) + @tf.custom_gradient + def grad(dy): + dx = tf.where(y >= 0, dy, dy * alpha) + return dx, lambda ddx: tf.where(y >= 0, ddx, ddx * alpha) + return y, grad + return func(x) + +#---------------------------------------------------------------------------- +# Pixelwise feature vector normalization. + +def pixel_norm(x, epsilon=1e-8): + with tf.variable_scope('PixelNorm'): + epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon) + +#---------------------------------------------------------------------------- +# Instance normalization. + +def instance_norm(x, epsilon=1e-8): + assert len(x.shape) == 4 # NCHW + with tf.variable_scope('InstanceNorm'): + orig_dtype = x.dtype + x = tf.cast(x, tf.float32) + x -= tf.reduce_mean(x, axis=[2,3], keepdims=True) + epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon') + x *= tf.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon) + x = tf.cast(x, orig_dtype) + return x + +#---------------------------------------------------------------------------- +# Style modulation. + +def style_mod(x, dlatent, **kwargs): + with tf.variable_scope('StyleMod'): + style = apply_bias(dense(dlatent, fmaps=x.shape[1]*2, gain=1, **kwargs)) + style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2)) + return x * (style[:,0] + 1) + style[:,1] + +#---------------------------------------------------------------------------- +# Noise input. + +def apply_noise(x, noise_var=None, randomize_noise=True): + assert len(x.shape) == 4 # NCHW + with tf.variable_scope('Noise'): + if noise_var is None or randomize_noise: + noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype) + else: + noise = tf.cast(noise_var, x.dtype) + weight = tf.get_variable('weight', shape=[x.shape[1].value], initializer=tf.initializers.zeros()) + return x + noise * tf.reshape(tf.cast(weight, x.dtype), [1, -1, 1, 1]) + +#---------------------------------------------------------------------------- +# Minibatch standard deviation. + +def minibatch_stddev_layer(x, group_size=4, num_new_features=1): + with tf.variable_scope('MinibatchStddev'): + group_size = tf.minimum(group_size, tf.shape(x)[0]) # Minibatch must be divisible by (or smaller than) group_size. + s = x.shape # [NCHW] Input shape. + y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]]) # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c. + y = tf.cast(y, tf.float32) # [GMncHW] Cast to FP32. + y -= tf.reduce_mean(y, axis=0, keepdims=True) # [GMncHW] Subtract mean over group. + y = tf.reduce_mean(tf.square(y), axis=0) # [MncHW] Calc variance over group. + y = tf.sqrt(y + 1e-8) # [MncHW] Calc stddev over group. + y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True) # [Mn111] Take average over fmaps and pixels. 
+ y = tf.reduce_mean(y, axis=[2]) # [Mn11] Split channels into c channel groups + y = tf.cast(y, x.dtype) # [Mn11] Cast back to original data type. + y = tf.tile(y, [group_size, 1, s[2], s[3]]) # [NnHW] Replicate over group and pixels. + return tf.concat([x, y], axis=1) # [NCHW] Append as new fmap. + +#---------------------------------------------------------------------------- +# Style-based generator used in the StyleGAN paper. +# Composed of two sub-networks (G_mapping and G_synthesis) that are defined below. + +def G_style( + latents_in, # First input: Latent vectors (Z) [minibatch, latent_size]. + labels_in, # Second input: Conditioning labels [minibatch, label_size]. + truncation_psi = 0.7, # Style strength multiplier for the truncation trick. None = disable. + truncation_cutoff = 8, # Number of layers for which to apply the truncation trick. None = disable. + truncation_psi_val = None, # Value for truncation_psi to use during validation. + truncation_cutoff_val = None, # Value for truncation_cutoff to use during validation. + dlatent_avg_beta = 0.995, # Decay for tracking the moving average of W during training. None = disable. + style_mixing_prob = 0.9, # Probability of mixing styles during training. None = disable. + is_training = False, # Network is under training? Enables and disables specific features. + is_validation = False, # Network is under validation? Chooses which value to use for truncation_psi. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + components = dnnlib.EasyDict(), # Container for sub-networks. Retained between calls. + **kwargs): # Arguments for sub-networks (G_mapping and G_synthesis). + + # Validate arguments. + assert not is_training or not is_validation + assert isinstance(components, dnnlib.EasyDict) + if is_validation: + truncation_psi = truncation_psi_val + truncation_cutoff = truncation_cutoff_val + if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1): + truncation_psi = None + if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0): + truncation_cutoff = None + if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1): + dlatent_avg_beta = None + if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0): + style_mixing_prob = None + + # Setup components. + if 'synthesis' not in components: + components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs) + num_layers = components.synthesis.input_shape[1] + dlatent_size = components.synthesis.input_shape[2] + if 'mapping' not in components: + components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs) + + # Setup variables. + lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False) + dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False) + + # Evaluate mapping network. + dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs) + + # Update moving average of W. 
+ if dlatent_avg_beta is not None: + with tf.variable_scope('DlatentAvg'): + batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0) + update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta)) + with tf.control_dependencies([update_op]): + dlatents = tf.identity(dlatents) + + # Perform style mixing regularization. + if style_mixing_prob is not None: + with tf.name_scope('StyleMix'): + latents2 = tf.random_normal(tf.shape(latents_in)) + dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs) + layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis] + cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2 + mixing_cutoff = tf.cond( + tf.random_uniform([], 0.0, 1.0) < style_mixing_prob, + lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32), + lambda: cur_layers) + dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2) + + # Apply truncation trick. + if truncation_psi is not None and truncation_cutoff is not None: + with tf.variable_scope('Truncation'): + layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis] + ones = np.ones(layer_idx.shape, dtype=np.float32) + coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones) + dlatents = tflib.lerp(dlatent_avg, dlatents, coefs) + + # Evaluate synthesis network. + with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]): + images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs) + return tf.identity(images_out, name='images_out') + +#---------------------------------------------------------------------------- +# Mapping network used in the StyleGAN paper. + +def G_mapping( + latents_in, # First input: Latent vectors (Z) [minibatch, latent_size]. + labels_in, # Second input: Conditioning labels [minibatch, label_size]. + latent_size = 512, # Latent vector (Z) dimensionality. + label_size = 0, # Label dimensionality, 0 if no labels. + dlatent_size = 512, # Disentangled latent (W) dimensionality. + dlatent_broadcast = None, # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size]. + mapping_layers = 8, # Number of mapping layers. + mapping_fmaps = 512, # Number of activations in the mapping layers. + mapping_lrmul = 0.01, # Learning rate multiplier for the mapping layers. + mapping_nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu'. + use_wscale = True, # Enable equalized learning rate? + normalize_latents = True, # Normalize latent vectors (Z) before feeding them to the mapping layers? + dtype = 'float32', # Data type to use for activations and outputs. + **_kwargs): # Ignore unrecognized keyword args. + + act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[mapping_nonlinearity] + + # Inputs. + latents_in.set_shape([None, latent_size]) + labels_in.set_shape([None, label_size]) + latents_in = tf.cast(latents_in, dtype) + labels_in = tf.cast(labels_in, dtype) + x = latents_in + + # Embed labels and concatenate them with latents. + if label_size: + with tf.variable_scope('LabelConcat'): + w = tf.get_variable('weight', shape=[label_size, latent_size], initializer=tf.initializers.random_normal()) + y = tf.matmul(labels_in, tf.cast(w, dtype)) + x = tf.concat([x, y], axis=1) + + # Normalize latents. + if normalize_latents: + x = pixel_norm(x) + + # Mapping layers. 
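+    # A stack of mapping_layers fully-connected layers; the last one projects to dlatent_size,
+    # and mapping_lrmul scales the effective learning rate of all of them.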
+ for layer_idx in range(mapping_layers): + with tf.variable_scope('Dense%d' % layer_idx): + fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps + x = dense(x, fmaps=fmaps, gain=gain, use_wscale=use_wscale, lrmul=mapping_lrmul) + x = apply_bias(x, lrmul=mapping_lrmul) + x = act(x) + + # Broadcast. + if dlatent_broadcast is not None: + with tf.variable_scope('Broadcast'): + x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1]) + + # Output. + assert x.dtype == tf.as_dtype(dtype) + return tf.identity(x, name='dlatents_out') + +#---------------------------------------------------------------------------- +# Synthesis network used in the StyleGAN paper. + +def G_synthesis( + dlatents_in, # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size]. + dlatent_size = 512, # Disentangled latent (W) dimensionality. + num_channels = 3, # Number of output color channels. + resolution = 1024, # Output resolution. + fmap_base = 8192, # Overall multiplier for the number of feature maps. + fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. + fmap_max = 512, # Maximum number of feature maps in any layer. + use_styles = True, # Enable style inputs? + const_input_layer = True, # First layer is a learned constant? + use_noise = True, # Enable noise inputs? + randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables. + nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu' + use_wscale = True, # Enable equalized learning rate? + use_pixel_norm = False, # Enable pixelwise feature vector normalization? + use_instance_norm = True, # Enable instance normalization? + dtype = 'float32', # Data type to use for activations and outputs. + fused_scale = 'auto', # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically. + blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering. + structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + force_clean_graph = False, # True = construct a clean graph that looks nice in TensorBoard, False = default behavior. + **_kwargs): # Ignore unrecognized keyword args. + + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + def blur(x): return blur2d(x, blur_filter) if blur_filter else x + if is_template_graph: force_clean_graph = True + if force_clean_graph: randomize_noise = False + if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive' + act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity] + num_layers = resolution_log2 * 2 - 2 + num_styles = num_layers if use_styles else 1 + images_out = None + + # Primary inputs. + dlatents_in.set_shape([None, num_styles, dlatent_size]) + dlatents_in = tf.cast(dlatents_in, dtype) + lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype) + + # Noise inputs. 
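+    # One single-channel noise image per layer, matching that layer's resolution; stored as
+    # non-trainable variables so the noise can be held fixed when randomize_noise is False.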
+ noise_inputs = [] + if use_noise: + for layer_idx in range(num_layers): + res = layer_idx // 2 + 2 + shape = [1, use_noise, 2**res, 2**res] + noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape, initializer=tf.initializers.random_normal(), trainable=False)) + + # Things to do at the end of each layer. + def layer_epilogue(x, layer_idx): + if use_noise: + x = apply_noise(x, noise_inputs[layer_idx], randomize_noise=randomize_noise) + x = apply_bias(x) + x = act(x) + if use_pixel_norm: + x = pixel_norm(x) + if use_instance_norm: + x = instance_norm(x) + if use_styles: + x = style_mod(x, dlatents_in[:, layer_idx], use_wscale=use_wscale) + return x + + # Early layers. + with tf.variable_scope('4x4'): + if const_input_layer: + with tf.variable_scope('Const'): + x = tf.get_variable('const', shape=[1, nf(1), 4, 4], initializer=tf.initializers.ones()) + x = layer_epilogue(tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1]), 0) + else: + with tf.variable_scope('Dense'): + x = dense(dlatents_in[:, 0], fmaps=nf(1)*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressing GAN + x = layer_epilogue(tf.reshape(x, [-1, nf(1), 4, 4]), 0) + with tf.variable_scope('Conv'): + x = layer_epilogue(conv2d(x, fmaps=nf(1), kernel=3, gain=gain, use_wscale=use_wscale), 1) + + # Building blocks for remaining layers. + def block(res, x): # res = 3..resolution_log2 + with tf.variable_scope('%dx%d' % (2**res, 2**res)): + with tf.variable_scope('Conv0_up'): + x = layer_epilogue(blur(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale)), res*2-4) + with tf.variable_scope('Conv1'): + x = layer_epilogue(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale), res*2-3) + return x + def torgb(res, x): # res = 2..resolution_log2 + lod = resolution_log2 - res + with tf.variable_scope('ToRGB_lod%d' % lod): + return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale)) + + # Fixed structure: simple and efficient, but does not support progressive growing. + if structure == 'fixed': + for res in range(3, resolution_log2 + 1): + x = block(res, x) + images_out = torgb(resolution_log2, x) + + # Linear structure: simple but inefficient. + if structure == 'linear': + images_out = torgb(2, x) + for res in range(3, resolution_log2 + 1): + lod = resolution_log2 - res + x = block(res, x) + img = torgb(res, x) + images_out = upscale2d(images_out) + with tf.variable_scope('Grow_lod%d' % lod): + images_out = tflib.lerp_clip(img, images_out, lod_in - lod) + + # Recursive structure: complex but efficient. + if structure == 'recursive': + def cset(cur_lambda, new_cond, new_lambda): + return lambda: tf.cond(new_cond, new_lambda, cur_lambda) + def grow(x, res, lod): + y = block(res, x) + img = lambda: upscale2d(torgb(res, y), 2**lod) + img = cset(img, (lod_in > lod), lambda: upscale2d(tflib.lerp(torgb(res, y), upscale2d(torgb(res - 1, x)), lod_in - lod), 2**lod)) + if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1)) + return img() + images_out = grow(x, 3, resolution_log2 - 3) + + assert images_out.dtype == tf.as_dtype(dtype) + return tf.identity(images_out, name='images_out') + +#---------------------------------------------------------------------------- +# Discriminator used in the StyleGAN paper. + +def D_basic( + images_in, # First input: Images [minibatch, channel, height, width]. + labels_in, # Second input: Labels [minibatch, label_size]. 
+ num_channels = 1, # Number of input color channels. Overridden based on dataset. + resolution = 32, # Input resolution. Overridden based on dataset. + label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. + fmap_base = 8192, # Overall multiplier for the number of feature maps. + fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution. + fmap_max = 512, # Maximum number of feature maps in any layer. + nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', + use_wscale = True, # Enable equalized learning rate? + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable. + mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer. + dtype = 'float32', # Data type to use for activations and outputs. + fused_scale = 'auto', # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically. + blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering. + structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically. + is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation. + **_kwargs): # Ignore unrecognized keyword args. + + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2**resolution_log2 and resolution >= 4 + def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + def blur(x): return blur2d(x, blur_filter) if blur_filter else x + if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive' + act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity] + + images_in.set_shape([None, num_channels, resolution, resolution]) + labels_in.set_shape([None, label_size]) + images_in = tf.cast(images_in, dtype) + labels_in = tf.cast(labels_in, dtype) + lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype) + scores_out = None + + # Building blocks. + def fromrgb(x, res): # res = 2..resolution_log2 + with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)): + return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, gain=gain, use_wscale=use_wscale))) + def block(x, res): # res = 2..resolution_log2 + with tf.variable_scope('%dx%d' % (2**res, 2**res)): + if res >= 3: # 8x8 and up + with tf.variable_scope('Conv0'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale))) + with tf.variable_scope('Conv1_down'): + x = act(apply_bias(conv2d_downscale2d(blur(x), fmaps=nf(res-2), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale))) + else: # 4x4 + if mbstd_group_size > 1: + x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features) + with tf.variable_scope('Conv'): + x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale))) + with tf.variable_scope('Dense0'): + x = act(apply_bias(dense(x, fmaps=nf(res-2), gain=gain, use_wscale=use_wscale))) + with tf.variable_scope('Dense1'): + x = apply_bias(dense(x, fmaps=max(label_size, 1), gain=1, use_wscale=use_wscale)) + return x + + # Fixed structure: simple and efficient, but does not support progressive growing. 
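+    # fromrgb maps the input image to feature maps at full resolution; each block then halves
+    # the spatial size until the 4x4 block produces the final score.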
+ if structure == 'fixed': + x = fromrgb(images_in, resolution_log2) + for res in range(resolution_log2, 2, -1): + x = block(x, res) + scores_out = block(x, 2) + + # Linear structure: simple but inefficient. + if structure == 'linear': + img = images_in + x = fromrgb(img, resolution_log2) + for res in range(resolution_log2, 2, -1): + lod = resolution_log2 - res + x = block(x, res) + img = downscale2d(img) + y = fromrgb(img, res - 1) + with tf.variable_scope('Grow_lod%d' % lod): + x = tflib.lerp_clip(x, y, lod_in - lod) + scores_out = block(x, 2) + + # Recursive structure: complex but efficient. + if structure == 'recursive': + def cset(cur_lambda, new_cond, new_lambda): + return lambda: tf.cond(new_cond, new_lambda, cur_lambda) + def grow(res, lod): + x = lambda: fromrgb(downscale2d(images_in, 2**lod), res) + if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1)) + x = block(x(), res); y = lambda: x + if res > 2: y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod)) + return y() + scores_out = grow(2, resolution_log2 - 2) + + # Label conditioning from "Which Training Methods for GANs do actually Converge?" + if label_size: + with tf.variable_scope('LabelSwitch'): + scores_out = tf.reduce_sum(scores_out * labels_in, axis=1, keepdims=True) + + assert scores_out.dtype == tf.as_dtype(dtype) + scores_out = tf.identity(scores_out, name='scores_out') + return scores_out + +#---------------------------------------------------------------------------- diff --git a/models/stylegan/stylegan_tf/training/training_loop.py b/models/stylegan/stylegan_tf/training/training_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ccb45b1a0321f1d938efa6a62229ffe396dcfe --- /dev/null +++ b/models/stylegan/stylegan_tf/training/training_loop.py @@ -0,0 +1,278 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# This work is licensed under the Creative Commons Attribution-NonCommercial +# 4.0 International License. To view a copy of this license, visit +# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + +"""Main training script.""" + +import os +import numpy as np +import tensorflow as tf +import dnnlib +import dnnlib.tflib as tflib +from dnnlib.tflib.autosummary import autosummary + +import config +import train +from training import dataset +from training import misc +from metrics import metric_base + +#---------------------------------------------------------------------------- +# Just-in-time processing of training images before feeding them to the networks. + +def process_reals(x, lod, mirror_augment, drange_data, drange_net): + with tf.name_scope('ProcessReals'): + with tf.name_scope('DynamicRange'): + x = tf.cast(x, tf.float32) + x = misc.adjust_dynamic_range(x, drange_data, drange_net) + if mirror_augment: + with tf.name_scope('MirrorAugment'): + s = tf.shape(x) + mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0) + mask = tf.tile(mask, [1, s[1], s[2], s[3]]) + x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3])) + with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail. 
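+            # Box-downscale the images by 2x, upscale back to the original size, and blend the
+            # result with the originals according to the fractional part of lod.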
+ s = tf.shape(x) + y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2]) + y = tf.reduce_mean(y, axis=[3, 5], keepdims=True) + y = tf.tile(y, [1, 1, 1, 2, 1, 2]) + y = tf.reshape(y, [-1, s[1], s[2], s[3]]) + x = tflib.lerp(x, y, lod - tf.floor(lod)) + with tf.name_scope('UpscaleLOD'): # Upscale to match the expected input/output size of the networks. + s = tf.shape(x) + factor = tf.cast(2 ** tf.floor(lod), tf.int32) + x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) + x = tf.tile(x, [1, 1, 1, factor, 1, factor]) + x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) + return x + +#---------------------------------------------------------------------------- +# Evaluate time-varying training parameters. + +def training_schedule( + cur_nimg, + training_set, + num_gpus, + lod_initial_resolution = 4, # Image resolution used at the beginning. + lod_training_kimg = 600, # Thousands of real images to show before doubling the resolution. + lod_transition_kimg = 600, # Thousands of real images to show when fading in new layers. + minibatch_base = 16, # Maximum minibatch size, divided evenly among GPUs. + minibatch_dict = {}, # Resolution-specific overrides. + max_minibatch_per_gpu = {}, # Resolution-specific maximum minibatch size per GPU. + G_lrate_base = 0.001, # Learning rate for the generator. + G_lrate_dict = {}, # Resolution-specific overrides. + D_lrate_base = 0.001, # Learning rate for the discriminator. + D_lrate_dict = {}, # Resolution-specific overrides. + lrate_rampup_kimg = 0, # Duration of learning rate ramp-up. + tick_kimg_base = 160, # Default interval of progress snapshots. + tick_kimg_dict = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20}): # Resolution-specific overrides. + + # Initialize result dict. + s = dnnlib.EasyDict() + s.kimg = cur_nimg / 1000.0 + + # Training phase. + phase_dur = lod_training_kimg + lod_transition_kimg + phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0 + phase_kimg = s.kimg - phase_idx * phase_dur + + # Level-of-detail and resolution. + s.lod = training_set.resolution_log2 + s.lod -= np.floor(np.log2(lod_initial_resolution)) + s.lod -= phase_idx + if lod_transition_kimg > 0: + s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg + s.lod = max(s.lod, 0.0) + s.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(s.lod))) + + # Minibatch size. + s.minibatch = minibatch_dict.get(s.resolution, minibatch_base) + s.minibatch -= s.minibatch % num_gpus + if s.resolution in max_minibatch_per_gpu: + s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus) + + # Learning rate. + s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base) + s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base) + if lrate_rampup_kimg > 0: + rampup = min(s.kimg / lrate_rampup_kimg, 1.0) + s.G_lrate *= rampup + s.D_lrate *= rampup + + # Other parameters. + s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base) + return s + +#---------------------------------------------------------------------------- +# Main training script. + +def training_loop( + submit_config, + G_args = {}, # Options for generator network. + D_args = {}, # Options for discriminator network. + G_opt_args = {}, # Options for generator optimizer. + D_opt_args = {}, # Options for discriminator optimizer. + G_loss_args = {}, # Options for generator loss. + D_loss_args = {}, # Options for discriminator loss. + dataset_args = {}, # Options for dataset.load_dataset(). 
+ sched_args = {}, # Options for train.TrainingSchedule. + grid_args = {}, # Options for train.setup_snapshot_image_grid(). + metric_arg_list = [], # Options for MetricGroup. + tf_config = {}, # Options for tflib.init_tf(). + G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. + D_repeats = 1, # How many times the discriminator is trained per G iteration. + minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. + reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? + total_kimg = 15000, # Total length of the training, measured in thousands of real images. + mirror_augment = False, # Enable mirror augment? + drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. + image_snapshot_ticks = 1, # How often to export image snapshots? + network_snapshot_ticks = 10, # How often to export network snapshots? + save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? + save_weight_histograms = False, # Include weight histograms in the tfevents file? + resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch. + resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. + resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. + resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting. + + # Initialize dnnlib and TensorFlow. + ctx = dnnlib.RunContext(submit_config, train) + tflib.init_tf(tf_config) + + # Load training set. + training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) + + # Construct networks. + with tf.device('/gpu:0'): + if resume_run_id is not None: + network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) + print('Loading networks from "%s"...' 
% network_pkl) + G, D, Gs = misc.load_pkl(network_pkl) + else: + print('Constructing networks...') + G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) + D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) + Gs = G.clone('Gs') + G.print_layers(); D.print_layers() + + print('Building TensorFlow graph...') + with tf.name_scope('Inputs'), tf.device('/cpu:0'): + lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) + lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) + minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) + minibatch_split = minibatch_in // submit_config.num_gpus + Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 + + G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) + D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) + for gpu in range(submit_config.num_gpus): + with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): + G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') + D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') + lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)] + reals, labels = training_set.get_minibatch_tf() + reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) + with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops): + G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) + with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): + D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) + G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) + D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) + G_train_op = G_opt.apply_updates() + D_train_op = D_opt.apply_updates() + + Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) + with tf.device('/gpu:0'): + try: + peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() + except tf.errors.NotFoundError: + peak_gpu_mem_op = tf.constant(0) + + print('Setting up snapshot image grid...') + grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args) + sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) + grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) + + print('Setting up run dir...') + misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) + misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) + summary_log = tf.summary.FileWriter(submit_config.run_dir) + if save_tf_graph: + summary_log.add_graph(tf.get_default_graph()) + if save_weight_histograms: + G.setup_weight_histograms(); D.setup_weight_histograms() + metrics = metric_base.MetricGroup(metric_arg_list) + + print('Training...\n') + ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) + 
maintenance_time = ctx.get_last_update_interval() + cur_nimg = int(resume_kimg * 1000) + cur_tick = 0 + tick_start_nimg = cur_nimg + prev_lod = -1.0 + while cur_nimg < total_kimg * 1000: + if ctx.should_stop(): break + + # Choose training parameters and configure training ops. + sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) + training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) + if reset_opt_for_new_lod: + if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): + G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() + prev_lod = sched.lod + + # Run training ops. + for _mb_repeat in range(minibatch_repeats): + for _D_repeat in range(D_repeats): + tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch}) + cur_nimg += sched.minibatch + tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch}) + + # Perform maintenance tasks once per tick. + done = (cur_nimg >= total_kimg * 1000) + if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: + cur_tick += 1 + tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 + tick_start_nimg = cur_nimg + tick_time = ctx.get_time_since_last_update() + total_time = ctx.get_time_since_start() + resume_time + + # Report progress. + print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % ( + autosummary('Progress/tick', cur_tick), + autosummary('Progress/kimg', cur_nimg / 1000.0), + autosummary('Progress/lod', sched.lod), + autosummary('Progress/minibatch', sched.minibatch), + dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), + autosummary('Timing/sec_per_tick', tick_time), + autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), + autosummary('Timing/maintenance_sec', maintenance_time), + autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) + autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) + autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) + + # Save snapshots. + if cur_tick % image_snapshot_ticks == 0 or done: + grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) + misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) + if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: + pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) + misc.save_pkl((G, D, Gs), pkl) + metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) + + # Update summaries and RunContext. + metrics.update_autosummaries() + tflib.autosummary.save_summaries(summary_log, cur_nimg) + ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) + maintenance_time = ctx.get_last_update_interval() - tick_time + + # Write final results. 
+ misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) + summary_log.close() + + ctx.close() + +#---------------------------------------------------------------------------- diff --git a/models/stylegan2/__init__.py b/models/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87739d5c18fe051149018f275983ebf6380c8b54 --- /dev/null +++ b/models/stylegan2/__init__.py @@ -0,0 +1,16 @@ +import sys +import os +import shutil +import glob +import platform +from pathlib import Path + +current_path = os.getcwd() + +module_path = Path(__file__).parent / 'stylegan2-pytorch' +sys.path.append(str(module_path.resolve())) +os.chdir(module_path) + +from model import Generator + +os.chdir(current_path) \ No newline at end of file diff --git a/models/stylegan2/stylegan2-pytorch/.gitignore b/models/stylegan2/stylegan2-pytorch/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b8e449f3ff8a4951e8122cefa463ce506b590246 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +wandb/ +*.lmdb/ +*.pkl diff --git a/models/stylegan2/stylegan2-pytorch/LICENSE b/models/stylegan2/stylegan2-pytorch/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..915ca760bc639695e152e784d9dc2dbf71369b67 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/models/stylegan2/stylegan2-pytorch/LICENSE-FID b/models/stylegan2/stylegan2-pytorch/LICENSE-FID new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/LICENSE-FID @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS b/models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS new file mode 100644 index 0000000000000000000000000000000000000000..e269c6bdc77eecf327fd72b156b7ec3b6434066c --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/LICENSE-LPIPS @@ -0,0 +1,24 @@ +Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA b/models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA new file mode 100644 index 0000000000000000000000000000000000000000..288fb3247529fc0d19ee2040c29adc65886d9426 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/LICENSE-NVIDIA @@ -0,0 +1,101 @@ +Copyright (c) 2019, NVIDIA Corporation. All rights reserved. + + +Nvidia Source Code License-NC + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +"Nvidia Processors" means any central processing unit (CPU), graphics +processing unit (GPU), field-programmable gate array (FPGA), +application-specific integrated circuit (ASIC) or any combination +thereof designed, made, sold, or provided by Nvidia or its affiliates. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. 
Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. The Work or + derivative works thereof may be used or intended for use by Nvidia + or its affiliates commercially or non-commercially. As used herein, + "non-commercially" means for research or evaluation purposes only. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grants in Sections 2.1 and 2.2) will + terminate immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor's or its affiliates' names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grants in Sections 2.1 and + 2.2) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +======================================================================= diff --git a/models/stylegan2/stylegan2-pytorch/README.md b/models/stylegan2/stylegan2-pytorch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..325c7b4fe1ee3e4b72f48c0849b0c4a7136f368d --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/README.md @@ -0,0 +1,83 @@ +# StyleGAN 2 in PyTorch + +Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch + +## Notice + +I have tried to match official implementation as close as possible, but maybe there are some details I missed. So please use this implementation with care. + +## Requirements + +I have tested on: + +* PyTorch 1.3.1 +* CUDA 10.1/10.2 + +## Usage + +First create lmdb datasets: + +> python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH + +This will convert images to jpeg and pre-resizes it. This implementation does not use progressive growing, but you can create multiple resolution datasets using size arguments with comma separated lists, for the cases that you want to try another resolutions later. 
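+
+For example, with hypothetical paths, a dataset pre-resized to both 256px and 512px could be built like this:
+
+> python prepare_data.py --out data/ffhq.lmdb --n_worker 8 --size 256,512 ~/datasets/ffhq-images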
+
+Then you can train the model in a distributed setting:
+
+> python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH
+
+train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the script.
+
+### Convert weights from official checkpoints
+
+You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load the official checkpoints.
+
+Next, create a conda environment with TF-GPU and Torch-CPU (using GPU for both results in CUDA version mismatches):
+`conda create -n tf_torch python=3.7 requests tensorflow-gpu=1.14 cudatoolkit=10.0 numpy=1.14 pytorch=1.6 torchvision cpuonly -c pytorch` + +For example, if you cloned repositories in ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, You can convert it like this: + +> python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl + +This will create converted stylegan2-ffhq-config-f.pt file. + +If using GCC, you might have to set `-D_GLIBCXX_USE_CXX11_ABI=1` in `~/stylegan2/dnnlib/tflib/custom_ops.py`. + +### Generate samples + +> python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT + +You should change your size (--size 256 for example) if you train with another dimension. + +### Project images to latent spaces + +> python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ... + +## Pretrained Checkpoints + +[Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO) + +I have trained the 256px model on FFHQ 550k iterations. I got FID about 4.5. Maybe data preprocessing, resolution, training loop could made this difference, but currently I don't know the exact reason of FID differences. + +## Samples + +![Sample with truncation](doc/sample.png) + +At 110,000 iterations. (trained on 3.52M images) + +### Samples from converted weights + +![Sample from FFHQ](doc/stylegan2-ffhq-config-f.png) + +Sample from FFHQ (1024px) + +![Sample from LSUN Church](doc/stylegan2-church-config-f.png) + +Sample from LSUN Church (256px) + +## License + +Model details and custom CUDA kernel codes are from official repostiories: https://github.com/NVlabs/stylegan2 + +Codes for Learned Perceptual Image Patch Similarity, LPIPS came from https://github.com/richzhang/PerceptualSimilarity + +To match FID scores more closely to tensorflow official implementations, I have used FID Inception V3 implementations in https://github.com/mseitzer/pytorch-fid diff --git a/models/stylegan2/stylegan2-pytorch/calc_inception.py b/models/stylegan2/stylegan2-pytorch/calc_inception.py new file mode 100644 index 0000000000000000000000000000000000000000..5daa531475c377a73ffa256bdf84bb662e144215 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/calc_inception.py @@ -0,0 +1,116 @@ +import argparse +import pickle +import os + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.models import inception_v3, Inception3 +import numpy as np +from tqdm import tqdm + +from inception import InceptionV3 +from dataset import MultiResolutionDataset + + +class Inception3Feature(Inception3): + def forward(self, x): + if x.shape[2] != 299 or x.shape[3] != 299: + x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) + + x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 + x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 + x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 + x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 + + x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 + x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 + x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 + + x = self.Mixed_5b(x) # 35 x 35 x 192 + x = self.Mixed_5c(x) # 35 x 35 x 256 + x = self.Mixed_5d(x) # 35 x 35 x 288 + + x = self.Mixed_6a(x) # 35 x 35 x 288 + x = self.Mixed_6b(x) # 17 x 17 x 768 + x = self.Mixed_6c(x) # 17 x 17 x 768 + x = self.Mixed_6d(x) # 17 x 17 x 768 + x = self.Mixed_6e(x) # 17 x 17 x 768 + + x = self.Mixed_7a(x) # 17 x 17 x 768 + x = self.Mixed_7b(x) # 8 x 8 
x 1280 + x = self.Mixed_7c(x) # 8 x 8 x 2048 + + x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 + + return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 + + +def load_patched_inception_v3(): + # inception = inception_v3(pretrained=True) + # inception_feat = Inception3Feature() + # inception_feat.load_state_dict(inception.state_dict()) + inception_feat = InceptionV3([3], normalize_input=False) + + return inception_feat + + +@torch.no_grad() +def extract_features(loader, inception, device): + pbar = tqdm(loader) + + feature_list = [] + + for img in pbar: + img = img.to(device) + feature = inception(img)[0].view(img.shape[0], -1) + feature_list.append(feature.to('cpu')) + + features = torch.cat(feature_list, 0) + + return features + + +if __name__ == '__main__': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + parser = argparse.ArgumentParser( + description='Calculate Inception v3 features for datasets' + ) + parser.add_argument('--size', type=int, default=256) + parser.add_argument('--batch', default=64, type=int, help='batch size') + parser.add_argument('--n_sample', type=int, default=50000) + parser.add_argument('--flip', action='store_true') + parser.add_argument('path', metavar='PATH', help='path to datset lmdb file') + + args = parser.parse_args() + + inception = load_patched_inception_v3() + inception = nn.DataParallel(inception).eval().to(device) + + transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) + loader = DataLoader(dset, batch_size=args.batch, num_workers=4) + + features = extract_features(loader, inception, device).numpy() + + features = features[: args.n_sample] + + print(f'extracted {features.shape[0]} features') + + mean = np.mean(features, 0) + cov = np.cov(features, rowvar=False) + + name = os.path.splitext(os.path.basename(args.path))[0] + + with open(f'inception_{name}.pkl', 'wb') as f: + pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f) diff --git a/models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore b/models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4b6ebe5ff71ddf0402c7083d1325c4d4bcf1b045 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/checkpoint/.gitignore @@ -0,0 +1 @@ +*.pt diff --git a/models/stylegan2/stylegan2-pytorch/convert_weight.py b/models/stylegan2/stylegan2-pytorch/convert_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..09b0a02dc48e3a8736f65bfe337a8c59aa206029 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/convert_weight.py @@ -0,0 +1,283 @@ +import argparse +import os +import sys +import pickle +import math + +import torch +import numpy as np +from torchvision import utils + +from model import Generator, Discriminator + + +def convert_modconv(vars, source_name, target_name, flip=False): + weight = vars[source_name + '/weight'].value().eval() + mod_weight = vars[source_name + '/mod_weight'].value().eval() + mod_bias = vars[source_name + '/mod_bias'].value().eval() + noise = vars[source_name + '/noise_strength'].value().eval() + bias = vars[source_name + '/bias'].value().eval() + + dic = { + 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), + 'conv.modulation.weight': mod_weight.transpose((1, 0)), + 'conv.modulation.bias': mod_bias 
+ 1, + 'noise.weight': np.array([noise]), + 'activate.bias': bias, + } + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + '.' + k] = torch.from_numpy(v) + + if flip: + dic_torch[target_name + '.conv.weight'] = torch.flip( + dic_torch[target_name + '.conv.weight'], [3, 4] + ) + + return dic_torch + + +def convert_conv(vars, source_name, target_name, bias=True, start=0): + weight = vars[source_name + '/weight'].value().eval() + + dic = {'weight': weight.transpose((3, 2, 0, 1))} + + if bias: + dic['bias'] = vars[source_name + '/bias'].value().eval() + + dic_torch = {} + + dic_torch[target_name + f'.{start}.weight'] = torch.from_numpy(dic['weight']) + + if bias: + dic_torch[target_name + f'.{start + 1}.bias'] = torch.from_numpy(dic['bias']) + + return dic_torch + + +def convert_torgb(vars, source_name, target_name): + weight = vars[source_name + '/weight'].value().eval() + mod_weight = vars[source_name + '/mod_weight'].value().eval() + mod_bias = vars[source_name + '/mod_bias'].value().eval() + bias = vars[source_name + '/bias'].value().eval() + + dic = { + 'conv.weight': np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), + 'conv.modulation.weight': mod_weight.transpose((1, 0)), + 'conv.modulation.bias': mod_bias + 1, + 'bias': bias.reshape((1, 3, 1, 1)), + } + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + '.' + k] = torch.from_numpy(v) + + return dic_torch + + +def convert_dense(vars, source_name, target_name): + weight = vars[source_name + '/weight'].value().eval() + bias = vars[source_name + '/bias'].value().eval() + + dic = {'weight': weight.transpose((1, 0)), 'bias': bias} + + dic_torch = {} + + for k, v in dic.items(): + dic_torch[target_name + '.' + k] = torch.from_numpy(v) + + return dic_torch + + +def update(state_dict, new): + for k, v in new.items(): + if k not in state_dict: + raise KeyError(k + ' is not found') + + if v.shape != state_dict[k].shape: + raise ValueError(f'Shape mismatch: {v.shape} vs {state_dict[k].shape}') + + state_dict[k] = v + + +def discriminator_fill_statedict(statedict, vars, size): + log_size = int(math.log(size, 2)) + + update(statedict, convert_conv(vars, f'{size}x{size}/FromRGB', 'convs.0')) + + conv_i = 1 + + for i in range(log_size - 2, 0, -1): + reso = 4 * 2 ** i + update( + statedict, + convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'), + ) + update( + statedict, + convert_conv( + vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1 + ), + ) + update( + statedict, + convert_conv( + vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False + ), + ) + conv_i += 1 + + update(statedict, convert_conv(vars, f'4x4/Conv', 'final_conv')) + update(statedict, convert_dense(vars, f'4x4/Dense0', 'final_linear.0')) + update(statedict, convert_dense(vars, f'Output', 'final_linear.1')) + + return statedict + + +def fill_statedict(state_dict, vars, size): + log_size = int(math.log(size, 2)) + + for i in range(8): + update(state_dict, convert_dense(vars, f'G_mapping/Dense{i}', f'style.{i + 1}')) + + update( + state_dict, + { + 'input.input': torch.from_numpy( + vars['G_synthesis/4x4/Const/const'].value().eval() + ) + }, + ) + + update(state_dict, convert_torgb(vars, 'G_synthesis/4x4/ToRGB', 'to_rgb1')) + + for i in range(log_size - 2): + reso = 4 * 2 ** (i + 1) + update( + state_dict, + convert_torgb(vars, f'G_synthesis/{reso}x{reso}/ToRGB', f'to_rgbs.{i}'), + ) + + update(state_dict, convert_modconv(vars, 'G_synthesis/4x4/Conv', 'conv1')) + + conv_i = 0 + + for i in 
range(log_size - 2): + reso = 4 * 2 ** (i + 1) + update( + state_dict, + convert_modconv( + vars, + f'G_synthesis/{reso}x{reso}/Conv0_up', + f'convs.{conv_i}', + flip=True, + ), + ) + update( + state_dict, + convert_modconv( + vars, f'G_synthesis/{reso}x{reso}/Conv1', f'convs.{conv_i + 1}' + ), + ) + conv_i += 2 + + for i in range(0, (log_size - 2) * 2 + 1): + update( + state_dict, + { + f'noises.noise_{i}': torch.from_numpy( + vars[f'G_synthesis/noise{i}'].value().eval() + ) + }, + ) + + return state_dict + + +if __name__ == '__main__': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print('Using PyTorch device', device) + + parser = argparse.ArgumentParser() + parser.add_argument('--repo', type=str, required=True) + parser.add_argument('--gen', action='store_true') + parser.add_argument('--disc', action='store_true') + parser.add_argument('--channel_multiplier', type=int, default=2) + parser.add_argument('path', metavar='PATH') + + args = parser.parse_args() + + sys.path.append(args.repo) + + import dnnlib + from dnnlib import tflib + + tflib.init_tf() + + with open(args.path, 'rb') as f: + generator, discriminator, g_ema = pickle.load(f) + + size = g_ema.output_shape[2] + + g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) + state_dict = g.state_dict() + state_dict = fill_statedict(state_dict, g_ema.vars, size) + + g.load_state_dict(state_dict) + + latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval()) + + ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg} + + if args.gen: + g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) + g_train_state = g_train.state_dict() + g_train_state = fill_statedict(g_train_state, generator.vars, size) + ckpt['g'] = g_train_state + + if args.disc: + disc = Discriminator(size, channel_multiplier=args.channel_multiplier) + d_state = disc.state_dict() + d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) + ckpt['d'] = d_state + + name = os.path.splitext(os.path.basename(args.path))[0] + outpath = os.path.join(os.getcwd(), f'{name}.pt') + print('Saving', outpath) + try: + torch.save(ckpt, outpath, _use_new_zipfile_serialization=False) + except TypeError: + torch.save(ckpt, outpath) + + + print('Generating TF-Torch comparison images') + batch_size = {256: 8, 512: 4, 1024: 2} + n_sample = batch_size.get(size, 4) + + g = g.to(device) + + z = np.random.RandomState(0).randn(n_sample, 512).astype('float32') + + with torch.no_grad(): + img_pt, _ = g( + [torch.from_numpy(z).to(device)], + truncation=0.5, + truncation_latent=latent_avg.to(device), + ) + + img_tf = g_ema.run(z, None, randomize_noise=False) + img_tf = torch.from_numpy(img_tf).to(device) + + img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp( + 0.0, 1.0 + ) + + img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) + utils.save_image( + img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1) + ) + print('Done') + diff --git a/models/stylegan2/stylegan2-pytorch/dataset.py b/models/stylegan2/stylegan2-pytorch/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7713ea2f8bc94d202d2dfbe830af3cb96b1e803d --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/dataset.py @@ -0,0 +1,40 @@ +from io import BytesIO + +import lmdb +from PIL import Image +from torch.utils.data import Dataset + + +class MultiResolutionDataset(Dataset): + def __init__(self, path, transform, resolution=256): + self.env = lmdb.open( + path, + max_readers=32, + 
readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + + if not self.env: + raise IOError('Cannot open lmdb dataset', path) + + with self.env.begin(write=False) as txn: + self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) + + self.resolution = resolution + self.transform = transform + + def __len__(self): + return self.length + + def __getitem__(self, index): + with self.env.begin(write=False) as txn: + key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') + img_bytes = txn.get(key) + + buffer = BytesIO(img_bytes) + img = Image.open(buffer) + img = self.transform(img) + + return img diff --git a/models/stylegan2/stylegan2-pytorch/distributed.py b/models/stylegan2/stylegan2-pytorch/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..51fa243257ef302e2015d5ff36ac531b86a9a0ce --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/distributed.py @@ -0,0 +1,126 @@ +import math +import pickle + +import torch +from torch import distributed as dist +from torch.utils.data.sampler import Sampler + + +def get_rank(): + if not dist.is_available(): + return 0 + + if not dist.is_initialized(): + return 0 + + return dist.get_rank() + + +def synchronize(): + if not dist.is_available(): + return + + if not dist.is_initialized(): + return + + world_size = dist.get_world_size() + + if world_size == 1: + return + + dist.barrier() + + +def get_world_size(): + if not dist.is_available(): + return 1 + + if not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def reduce_sum(tensor): + if not dist.is_available(): + return tensor + + if not dist.is_initialized(): + return tensor + + tensor = tensor.clone() + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor + + +def gather_grad(params): + world_size = get_world_size() + + if world_size == 1: + return + + for param in params: + if param.grad is not None: + dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) + param.grad.data.div_(world_size) + + +def all_gather(data): + world_size = get_world_size() + + if world_size == 1: + return [data] + + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to('cuda') + + local_size = torch.IntTensor([tensor.numel()]).to('cuda') + size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) + + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') + tensor = torch.cat((tensor, padding), 0) + + dist.all_gather(tensor_list, tensor) + + data_list = [] + + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in sorted(loss_dict.keys()): + keys.append(k) + losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in zip(keys, losses)} + + return reduced_losses diff --git a/models/stylegan2/stylegan2-pytorch/fid.py b/models/stylegan2/stylegan2-pytorch/fid.py new file mode 100644 
index 0000000000000000000000000000000000000000..c05eeda26b3a5ae5be060c158fc7a74d4ccbfb5f --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/fid.py @@ -0,0 +1,107 @@ +import argparse +import pickle + +import torch +from torch import nn +import numpy as np +from scipy import linalg +from tqdm import tqdm + +from model import Generator +from calc_inception import load_patched_inception_v3 + + +@torch.no_grad() +def extract_feature_from_samples( + generator, inception, truncation, truncation_latent, batch_size, n_sample, device +): + n_batch = n_sample // batch_size + resid = n_sample - (n_batch * batch_size) + batch_sizes = [batch_size] * n_batch + [resid] + features = [] + + for batch in tqdm(batch_sizes): + latent = torch.randn(batch, 512, device=device) + img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) + feat = inception(img)[0].view(img.shape[0], -1) + features.append(feat.to('cpu')) + + features = torch.cat(features, 0) + + return features + + +def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): + cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) + + if not np.isfinite(cov_sqrt).all(): + print('product of cov matrices is singular') + offset = np.eye(sample_cov.shape[0]) * eps + cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) + + if np.iscomplexobj(cov_sqrt): + if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): + m = np.max(np.abs(cov_sqrt.imag)) + + raise ValueError(f'Imaginary component {m}') + + cov_sqrt = cov_sqrt.real + + mean_diff = sample_mean - real_mean + mean_norm = mean_diff @ mean_diff + + trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) + + fid = mean_norm + trace + + return fid + + +if __name__ == '__main__': + device = 'cuda' + + parser = argparse.ArgumentParser() + + parser.add_argument('--truncation', type=float, default=1) + parser.add_argument('--truncation_mean', type=int, default=4096) + parser.add_argument('--batch', type=int, default=64) + parser.add_argument('--n_sample', type=int, default=50000) + parser.add_argument('--size', type=int, default=256) + parser.add_argument('--inception', type=str, default=None, required=True) + parser.add_argument('ckpt', metavar='CHECKPOINT') + + args = parser.parse_args() + + ckpt = torch.load(args.ckpt) + + g = Generator(args.size, 512, 8).to(device) + g.load_state_dict(ckpt['g_ema']) + g = nn.DataParallel(g) + g.eval() + + if args.truncation < 1: + with torch.no_grad(): + mean_latent = g.mean_latent(args.truncation_mean) + + else: + mean_latent = None + + inception = nn.DataParallel(load_patched_inception_v3()).to(device) + inception.eval() + + features = extract_feature_from_samples( + g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device + ).numpy() + print(f'extracted {features.shape[0]} features') + + sample_mean = np.mean(features, 0) + sample_cov = np.cov(features, rowvar=False) + + with open(args.inception, 'rb') as f: + embeds = pickle.load(f) + real_mean = embeds['mean'] + real_cov = embeds['cov'] + + fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) + + print('fid:', fid) diff --git a/models/stylegan2/stylegan2-pytorch/generate.py b/models/stylegan2/stylegan2-pytorch/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..4255c8cb0a16817b3f4d60783456bfa5cd15d018 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/generate.py @@ -0,0 +1,55 @@ +import argparse + +import torch +from torchvision import utils +from model import Generator 
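+# Example invocation (a sketch only; the flags mirror the argparse options
+# defined in __main__ below, and the checkpoint name is just that flag's
+# default value):
+#   python generate.py --ckpt stylegan2-ffhq-config-f.pt --size 1024 \
+#       --pics 20 --sample 1 --truncation 0.7
+# Images are written to sample/000000.png, sample/000001.png, ... so an
+# existing sample/ directory is assumed.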
+from tqdm import tqdm +def generate(args, g_ema, device, mean_latent): + + with torch.no_grad(): + g_ema.eval() + for i in tqdm(range(args.pics)): + sample_z = torch.randn(args.sample, args.latent, device=device) + + sample, _ = g_ema([sample_z], truncation=args.truncation, truncation_latent=mean_latent) + + utils.save_image( + sample, + f'sample/{str(i).zfill(6)}.png', + nrow=1, + normalize=True, + range=(-1, 1), + ) + +if __name__ == '__main__': + device = 'cuda' + + parser = argparse.ArgumentParser() + + parser.add_argument('--size', type=int, default=1024) + parser.add_argument('--sample', type=int, default=1) + parser.add_argument('--pics', type=int, default=20) + parser.add_argument('--truncation', type=float, default=1) + parser.add_argument('--truncation_mean', type=int, default=4096) + parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt") + parser.add_argument('--channel_multiplier', type=int, default=2) + + args = parser.parse_args() + + args.latent = 512 + args.n_mlp = 8 + + g_ema = Generator( + args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier + ).to(device) + checkpoint = torch.load(args.ckpt) + + g_ema.load_state_dict(checkpoint['g_ema']) + + if args.truncation < 1: + with torch.no_grad(): + mean_latent = g_ema.mean_latent(args.truncation_mean) + else: + mean_latent = None + + generate(args, g_ema, device, mean_latent) diff --git a/models/stylegan2/stylegan2-pytorch/inception.py b/models/stylegan2/stylegan2-pytorch/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..f3afed8123e595f65c1333dea7151e653a836e2b --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/inception.py @@ -0,0 +1,310 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. 
As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. 
+ """ + inception = models.inception_v3(num_classes=1008, + aux_logits=False, + pretrained=False) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 
1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/models/stylegan2/stylegan2-pytorch/lpips/__init__.py b/models/stylegan2/stylegan2-pytorch/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f86b7ee229b333a64f16d0091e988492f99c58 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/lpips/__init__.py @@ -0,0 +1,160 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from skimage.measure import compare_ssim +import torch +from torch.autograd import Variable + +from lpips import dist_model + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) + # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss + super(PerceptualLoss, self).__init__() + print('Setting up Perceptual loss...') + self.use_gpu = use_gpu + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model = dist_model.DistModel() + self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) + print('...[%s] initialized'%self.model.name()) + print('...Done') + + def forward(self, pred, target, normalize=False): + """ + Pred and target are Variables. + If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model.forward(target, pred) + +def normalize_tensor(in_feat,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) + return in_feat/(norm_factor+eps) + +def l2(p0, p1, range=255.): + return .5*np.mean((p0 / range - p1 / range)**2) + +def psnr(p0, p1, peak=255.): + return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) + +def dssim(p0, p1, range=255.): + return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 
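+# Usage sketch for the metric helpers above (illustrative only, not part of
+# the original API): given two HxWx3 uint8 numpy images im0 and im1,
+#   l2(im0, im1)     # half the mean squared error, computed on a 0-1 scale
+#   psnr(im0, im1)   # peak signal-to-noise ratio in dB (peak=255 by default)
+#   dssim(im0, im1)  # (1 - SSIM) / 2, via skimage's compare_ssim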
+ +def rgb2lab(in_img,mean_cent=False): + from skimage import color + img_lab = color.rgb2lab(in_img) + if(mean_cent): + img_lab[:,:,0] = img_lab[:,:,0]-50 + return img_lab + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if(mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + if(to_norm and not mc_only): + img_lab[:,:,0] = img_lab[:,:,0]-50 + img_lab = img_lab/100. + + return np2tensor(img_lab) + +def tensorlab2tensor(lab_tensor,return_inbnd=False): + from skimage import color + import warnings + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor)*100. + lab[:,:,0] = lab[:,:,0]+50 + + rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) + if(return_inbnd): + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype('uint8')) + mask = 1.*np.isclose(lab_back,lab,atol=2.) + mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) + return (im2tensor(rgb_back),mask) + else: + return im2tensor(rgb_back) + +def rgb2lab(input): + from skimage import color + return color.rgb2lab(input / 255.) + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. 
+ else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) diff --git a/models/stylegan2/stylegan2-pytorch/lpips/base_model.py b/models/stylegan2/stylegan2-pytorch/lpips/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8de1d16f0c7fa52d8067139abc6e769e96d0a6a1 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/lpips/base_model.py @@ -0,0 +1,58 @@ +import os +import numpy as np +import torch +from torch.autograd import Variable +from pdb import set_trace as st +from IPython import embed + +class BaseModel(): + def __init__(self): + pass; + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True, gpu_ids=[0]): + self.use_gpu = use_gpu + self.gpu_ids = gpu_ids + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print('Loading network from %s'%save_path) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'),flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') diff --git a/models/stylegan2/stylegan2-pytorch/lpips/dist_model.py b/models/stylegan2/stylegan2-pytorch/lpips/dist_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff0aa4ca6e4b217954c167787eaac1ca1f8e304 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/lpips/dist_model.py @@ -0,0 +1,284 @@ + +from __future__ import absolute_import + +import sys +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import itertools +from .base_model import BaseModel +from scipy.ndimage import zoom +import fractions +import functools +import skimage.transform +from tqdm import tqdm 
+ +from IPython import embed + +from . import networks_basic as networks +import lpips as util + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + gpu_ids - int array - [0] by default, gpus to use + ''' + BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.gpu_ids = gpu_ids + self.model_name = '%s [%s]'%(model,net) + + if(self.model == 'net-lin'): # pretrained net + linear layer + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = {} + if not use_gpu: + kw['map_location'] = 'cpu' + if(model_path is None): + import inspect + model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) + + if(not is_train): + print('Loading model from: %s'%model_path) + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif(self.model=='net'): # pretrained network + self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif(self.model in ['L2','l2']): + self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif(self.model in ['DSSIM','dssim','SSIM','ssim']): + self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." 
% self.model) + + self.parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = networks.BCERankingLoss() + self.parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + if(use_gpu): + self.net.to(gpu_ids[0]) + self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + if(self.is_train): + self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if(printNet): + print('---------- Networks initialized -------------') + networks.print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net.forward(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if(hasattr(module, 'weight') and module.kernel_size==(1,1)): + module.weight.data = torch.clamp(module.weight.data,min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + if(self.use_gpu): + self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + self.var_ref = Variable(self.input_ref,requires_grad=True) + self.var_p0 = Variable(self.input_p0,requires_grad=True) + self.var_p1 = Variable(self.input_p1,requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + self.d0 = self.forward(self.var_ref, self.var_p0) + self.d1 = self.forward(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) + + self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
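+        # self.var_judge holds the 2AFC human preference in [0, 1] (fraction of
+        # raters preferring p1), so 2*judge - 1 rescales it to [-1, 1] before it
+        # is passed to the BCERankingLoss constructed in initialize().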
+ + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self,d0,d1,judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) + self.old_lr = lr + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() + d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() + gts+=data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, 
negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = 
input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + +# Wrapper that gives name to tensor +class NamedTensor(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + +# Give each style a unique name +class StridedStyle(nn.ModuleList): + def __init__(self, n_latents): + super().__init__([NamedTensor() for _ in range(n_latents)]) + self.n_latents = n_latents + + def forward(self, x): + # x already strided + styles = [self[i](x[:, i, :]) for i in range(self.n_latents)] + return torch.stack(styles, dim=1) + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = 
out_channel + + self.n_latent = self.log_size * 2 - 2 + self.strided_style = StridedStyle(self.n_latent) + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_w=False, + noise=None, + randomize_noise=True, + ): + if not input_is_w: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) == 1: + # One global latent + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + elif len(styles) == 2: + # Latent mixing with two latents + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = self.strided_style(torch.cat([latent, latent2], 1)) + else: + # One latent per layer + assert len(styles) == self.n_latent, f'Expected {self.n_latents} latents, got {len(styles)}' + styles = torch.stack(styles, dim=1) # [N, 18, 512] + latent = self.strided_style(styles) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = 
ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out + diff --git a/models/stylegan2/stylegan2-pytorch/non_leaking.py b/models/stylegan2/stylegan2-pytorch/non_leaking.py new file mode 100644 index 0000000000000000000000000000000000000000..4e044f98e836ae2c011ea91246b304d5ab1a1422 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/non_leaking.py @@ -0,0 +1,137 @@ +import math + +import torch +from torch.nn import functional as F + + +def translate_mat(t_x, t_y): + batch = t_x.shape[0] + + mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) + translate = torch.stack((t_x, t_y), 1) + mat[:, :2, 2] = translate + + return mat + + +def rotate_mat(theta): + batch = theta.shape[0] + + mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) + sin_t = torch.sin(theta) + cos_t = torch.cos(theta) + rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) + mat[:, :2, :2] = rot + + return mat + + +def scale_mat(s_x, s_y): + batch = s_x.shape[0] + + mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) + mat[:, 0, 0] = s_x + mat[:, 1, 1] = s_y + + return mat + + +def lognormal_sample(size, mean=0, std=1): + return torch.empty(size).log_normal_(mean=mean, std=std) + + +def category_sample(size, categories): + category = torch.tensor(categories) + sample = torch.randint(high=len(categories), size=(size,)) + + return category[sample] + + +def uniform_sample(size, low, high): + return torch.empty(size).uniform_(low, high) + + +def normal_sample(size, mean=0, std=1): + return torch.empty(size).normal_(mean, std) + + +def bernoulli_sample(size, p): + return torch.empty(size).bernoulli_(p) + + +def random_affine_apply(p, transform, prev, eye): + size = transform.shape[0] + select = bernoulli_sample(size, p).view(size, 1, 1) + select_transform = select * transform + (1 - select) * eye + + return select_transform @ prev + + +def 
sample_affine(p, size, height, width): + G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1) + eye = G + + # flip + param = category_sample(size, (0, 1)) + Gc = scale_mat(1 - 2.0 * param, torch.ones(size)) + G = random_affine_apply(p, Gc, G, eye) + # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') + + # 90 rotate + param = category_sample(size, (0, 3)) + Gc = rotate_mat(-math.pi / 2 * param) + G = random_affine_apply(p, Gc, G, eye) + # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') + + # integer translate + param = uniform_sample(size, -0.125, 0.125) + param_height = torch.round(param * height) / height + param_width = torch.round(param * width) / width + Gc = translate_mat(param_width, param_height) + G = random_affine_apply(p, Gc, G, eye) + # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') + + # isotropic scale + param = lognormal_sample(size, std=0.2 * math.log(2)) + Gc = scale_mat(param, param) + G = random_affine_apply(p, Gc, G, eye) + # print('isotropic scale', G, scale_mat(param, param), sep='\n') + + p_rot = 1 - math.sqrt(1 - p) + + # pre-rotate + param = uniform_sample(size, -math.pi, math.pi) + Gc = rotate_mat(-param) + G = random_affine_apply(p_rot, Gc, G, eye) + # print('pre-rotate', G, rotate_mat(-param), sep='\n') + + # anisotropic scale + param = lognormal_sample(size, std=0.2 * math.log(2)) + Gc = scale_mat(param, 1 / param) + G = random_affine_apply(p, Gc, G, eye) + # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') + + # post-rotate + param = uniform_sample(size, -math.pi, math.pi) + Gc = rotate_mat(-param) + G = random_affine_apply(p_rot, Gc, G, eye) + # print('post-rotate', G, rotate_mat(-param), sep='\n') + + # fractional translate + param = normal_sample(size, std=0.125) + Gc = translate_mat(param, param) + G = random_affine_apply(p, Gc, G, eye) + # print('fractional translate', G, translate_mat(param, param), sep='\n') + + return G + + +def apply_affine(img, G): + grid = F.affine_grid( + torch.inverse(G).to(img)[:, :2, :], img.shape, align_corners=False + ) + img_affine = F.grid_sample( + img, grid, mode="bilinear", align_corners=False, padding_mode="reflection" + ) + + return img_affine diff --git a/models/stylegan2/stylegan2-pytorch/op/__init__.py b/models/stylegan2/stylegan2-pytorch/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/stylegan2/stylegan2-pytorch/op/fused_act.py b/models/stylegan2/stylegan2-pytorch/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3d464ae656920c6875bc877281cadb2eaa4105 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/fused_act.py @@ -0,0 +1,92 @@ +import os +import platform + +import torch +from torch import nn +from torch.autograd import Function +import torch.nn.functional as F +from torch.utils.cpp_extension import load + +use_fallback = False + +# Try loading precompiled, otherwise use native fallback +try: + import fused +except ModuleNotFoundError as e: + print('StyleGAN2: Optimized CUDA op FusedLeakyReLU not available, using native PyTorch fallback.') + use_fallback = True + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + 
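+        # Descriptive note: this custom backward reuses the same fused_bias_act
+        # CUDA kernel with grad=1, so the leaky-ReLU slope is chosen from the
+        # saved forward output ("ref"); the bias gradient is then obtained by
+        # summing the input gradient over every dimension except the channel
+        # dimension. This path only runs when the compiled `fused` extension
+        # imported above is available; otherwise fused_leaky_relu() below falls
+        # back to plain F.leaky_relu.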
ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + if use_fallback or input.device.type == 'cpu': + return scale * F.leaky_relu( + input + bias.view((1, -1)+(1,)*(input.ndim-2)), negative_slope=negative_slope + ) + else: + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp b/models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu b/models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/models/stylegan2/stylegan2-pytorch/op/setup.py b/models/stylegan2/stylegan2-pytorch/op/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..5b926d450579990c8f09b93cbc5ae4c06128ef8d --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/setup.py @@ -0,0 +1,33 @@ +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +from pathlib import Path + +# Usage: +# python setup.py install (or python setup.py bdist_wheel) +# NB: Windows: run from VS2017 x64 Native Tool Command Prompt + +rootdir = (Path(__file__).parent / '..' 
/ 'op').resolve() + +setup( + name='upfirdn2d', + ext_modules=[ + CUDAExtension('upfirdn2d_op', + [str(rootdir / 'upfirdn2d.cpp'), str(rootdir / 'upfirdn2d_kernel.cu')], + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) + +setup( + name='fused', + ext_modules=[ + CUDAExtension('fused', + [str(rootdir / 'fused_bias_act.cpp'), str(rootdir / 'fused_bias_act_kernel.cu')], + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) \ No newline at end of file diff --git a/models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp b/models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py b/models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca1f5c72098debfb0ffa1ba1b81eb92eb64d428 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/upfirdn2d.py @@ -0,0 +1,198 @@ +import os +import platform + +import torch +from torch.nn import functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + +use_fallback = False + +# Try loading precompiled, otherwise use native fallback +try: + import upfirdn2d_op +except ModuleNotFoundError as e: + print('StyleGAN2: Optimized CUDA op UpFirDn2d not available, using native PyTorch fallback.') + use_fallback = True + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + 
ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if use_fallback or input.device.type == "cpu": + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + else: + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu b/models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu new file mode 
100644 index 0000000000000000000000000000000000000000..a88bc7720da6cd54fccd0c4a03dd20fde85c063d --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= 
p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + 
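+    // Descriptive note: modes 1-6 dispatch to the templated shared-memory
+    // kernel specialized for the up/down factors and kernel sizes that
+    // StyleGAN2 actually uses; later matches override earlier ones, and
+    // anything still at mode == -1 falls back to upfirdn2d_kernel_large.
+    // tile_out_h / tile_out_w define the output tile handled by one block.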
tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} \ No newline at end of file diff --git a/models/stylegan2/stylegan2-pytorch/ppl.py b/models/stylegan2/stylegan2-pytorch/ppl.py new file mode 100644 index 0000000000000000000000000000000000000000..6b185c894ba719701baa6ac348e743a003ec5f27 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/ppl.py @@ -0,0 +1,104 @@ +import argparse + +import torch +from torch.nn import functional as F +import numpy as np +from tqdm import tqdm + +import lpips +from model import Generator + + +def normalize(x): + return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) + + +def slerp(a, b, t): + a = normalize(a) + b = normalize(b) + d = (a * b).sum(-1, keepdim=True) + p = t * torch.acos(d) + c = normalize(b - d * a) + d = a * torch.cos(p) + c * torch.sin(p) + + return normalize(d) + + +def lerp(a, b, t): + return a + (b - a) * t + + +if __name__ == '__main__': + device = 'cuda' + + parser = argparse.ArgumentParser() + + parser.add_argument('--space', choices=['z', 'w']) + parser.add_argument('--batch', type=int, default=64) + parser.add_argument('--n_sample', type=int, default=5000) + parser.add_argument('--size', type=int, default=256) + parser.add_argument('--eps', type=float, default=1e-4) + parser.add_argument('--crop', action='store_true') + 
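+    # Descriptive note: PPL is estimated as lpips(G(w_t), G(w_{t+eps})) / eps**2
+    # along random interpolation paths in W, keeping only distances between the
+    # 1st and 99th percentiles before averaging.
+    # Illustrative invocation (checkpoint path is an example):
+    #   python ppl.py --space w --size 256 --n_sample 5000 checkpoint/550000.pt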
parser.add_argument('ckpt', metavar='CHECKPOINT') + + args = parser.parse_args() + + latent_dim = 512 + + ckpt = torch.load(args.ckpt) + + g = Generator(args.size, latent_dim, 8).to(device) + g.load_state_dict(ckpt['g_ema']) + g.eval() + + percept = lpips.PerceptualLoss( + model='net-lin', net='vgg', use_gpu=device.startswith('cuda') + ) + + distances = [] + + n_batch = args.n_sample // args.batch + resid = args.n_sample - (n_batch * args.batch) + batch_sizes = [args.batch] * n_batch + [resid] + + with torch.no_grad(): + for batch in tqdm(batch_sizes): + noise = g.make_noise() + + inputs = torch.randn([batch * 2, latent_dim], device=device) + lerp_t = torch.rand(batch, device=device) + + if args.space == 'w': + latent = g.get_latent(inputs) + latent_t0, latent_t1 = latent[::2], latent[1::2] + latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) + latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) + latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) + + image, _ = g([latent_e], input_is_latent=True, noise=noise) + + if args.crop: + c = image.shape[2] // 8 + image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] + + factor = image.shape[2] // 256 + + if factor > 1: + image = F.interpolate( + image, size=(256, 256), mode='bilinear', align_corners=False + ) + + dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / ( + args.eps ** 2 + ) + distances.append(dist.to('cpu').numpy()) + + distances = np.concatenate(distances, 0) + + lo = np.percentile(distances, 1, interpolation='lower') + hi = np.percentile(distances, 99, interpolation='higher') + filtered_dist = np.extract( + np.logical_and(lo <= distances, distances <= hi), distances + ) + + print('ppl:', filtered_dist.mean()) diff --git a/models/stylegan2/stylegan2-pytorch/prepare_data.py b/models/stylegan2/stylegan2-pytorch/prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..db49cbda14aca3b2bc0268a4f40cd97f2dd603cc --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/prepare_data.py @@ -0,0 +1,82 @@ +import argparse +from io import BytesIO +import multiprocessing +from functools import partial + +from PIL import Image +import lmdb +from tqdm import tqdm +from torchvision import datasets +from torchvision.transforms import functional as trans_fn + + +def resize_and_convert(img, size, resample, quality=100): + img = trans_fn.resize(img, size, resample) + img = trans_fn.center_crop(img, size) + buffer = BytesIO() + img.save(buffer, format='jpeg', quality=quality) + val = buffer.getvalue() + + return val + + +def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100): + imgs = [] + + for size in sizes: + imgs.append(resize_and_convert(img, size, resample, quality)) + + return imgs + + +def resize_worker(img_file, sizes, resample): + i, file = img_file + img = Image.open(file) + img = img.convert('RGB') + out = resize_multiple(img, sizes=sizes, resample=resample) + + return i, out + + +def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): + resize_fn = partial(resize_worker, sizes=sizes, resample=resample) + + files = sorted(dataset.imgs, key=lambda x: x[0]) + files = [(i, file) for i, (file, label) in enumerate(files)] + total = 0 + + with multiprocessing.Pool(n_worker) as pool: + for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): + for size, img in zip(sizes, imgs): + key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') + + with env.begin(write=True) as txn: + txn.put(key, img) + + total += 1 + 
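+    # Descriptive note: each image is written once per requested size under the
+    # key f'{size}-{index:05d}', and the final 'length' entry records the image
+    # count that the dataset loader reads back at training time.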
+ with env.begin(write=True) as txn: + txn.put('length'.encode('utf-8'), str(total).encode('utf-8')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--out', type=str) + parser.add_argument('--size', type=str, default='128,256,512,1024') + parser.add_argument('--n_worker', type=int, default=8) + parser.add_argument('--resample', type=str, default='lanczos') + parser.add_argument('path', type=str) + + args = parser.parse_args() + + resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR} + resample = resample_map[args.resample] + + sizes = [int(s.strip()) for s in args.size.split(',')] + + print(f'Make dataset of image sizes:', ', '.join(str(s) for s in sizes)) + + imgset = datasets.ImageFolder(args.path) + + with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: + prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) diff --git a/models/stylegan2/stylegan2-pytorch/projector.py b/models/stylegan2/stylegan2-pytorch/projector.py new file mode 100644 index 0000000000000000000000000000000000000000..d63ad3573696cc22640cbeddc197d8cb15c52977 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/projector.py @@ -0,0 +1,203 @@ +import argparse +import math +import os + +import torch +from torch import optim +from torch.nn import functional as F +from torchvision import transforms +from PIL import Image +from tqdm import tqdm + +import lpips +from model import Generator + + +def noise_regularize(noises): + loss = 0 + + for noise in noises: + size = noise.shape[2] + + while True: + loss = ( + loss + + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) + + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) + ) + + if size <= 8: + break + + noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2]) + noise = noise.mean([3, 5]) + size //= 2 + + return loss + + +def noise_normalize_(noises): + for noise in noises: + mean = noise.mean() + std = noise.std() + + noise.data.add_(-mean).div_(std) + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + + return initial_lr * lr_ramp + + +def latent_noise(latent, strength): + noise = torch.randn_like(latent) * strength + + return latent + noise + + +def make_image(tensor): + return ( + tensor.detach() + .clamp_(min=-1, max=1) + .add(1) + .div_(2) + .mul(255) + .type(torch.uint8) + .permute(0, 2, 3, 1) + .to('cpu') + .numpy() + ) + + +if __name__ == '__main__': + device = 'cuda' + + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt', type=str, required=True) + parser.add_argument('--size', type=int, default=256) + parser.add_argument('--lr_rampup', type=float, default=0.05) + parser.add_argument('--lr_rampdown', type=float, default=0.25) + parser.add_argument('--lr', type=float, default=0.1) + parser.add_argument('--noise', type=float, default=0.05) + parser.add_argument('--noise_ramp', type=float, default=0.75) + parser.add_argument('--step', type=int, default=1000) + parser.add_argument('--noise_regularize', type=float, default=1e5) + parser.add_argument('--mse', type=float, default=0) + parser.add_argument('--w_plus', action='store_true') + parser.add_argument('files', metavar='FILES', nargs='+') + + args = parser.parse_args() + + n_mean_latent = 10000 + + resize = min(args.size, 256) + + transform = transforms.Compose( + [ + transforms.Resize(resize), + transforms.CenterCrop(resize), + transforms.ToTensor(), 
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + imgs = [] + + for imgfile in args.files: + img = transform(Image.open(imgfile).convert('RGB')) + imgs.append(img) + + imgs = torch.stack(imgs, 0).to(device) + + g_ema = Generator(args.size, 512, 8) + g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False) + g_ema.eval() + g_ema = g_ema.to(device) + + with torch.no_grad(): + noise_sample = torch.randn(n_mean_latent, 512, device=device) + latent_out = g_ema.style(noise_sample) + + latent_mean = latent_out.mean(0) + latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 + + percept = lpips.PerceptualLoss( + model='net-lin', net='vgg', use_gpu=device.startswith('cuda') + ) + + noises = g_ema.make_noise() + + latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(2, 1) + + if args.w_plus: + latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) + + latent_in.requires_grad = True + + for noise in noises: + noise.requires_grad = True + + optimizer = optim.Adam([latent_in] + noises, lr=args.lr) + + pbar = tqdm(range(args.step)) + latent_path = [] + + for i in pbar: + t = i / args.step + lr = get_lr(t, args.lr) + optimizer.param_groups[0]['lr'] = lr + noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2 + latent_n = latent_noise(latent_in, noise_strength.item()) + + img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises) + + batch, channel, height, width = img_gen.shape + + if height > 256: + factor = height // 256 + + img_gen = img_gen.reshape( + batch, channel, height // factor, factor, width // factor, factor + ) + img_gen = img_gen.mean([3, 5]) + + p_loss = percept(img_gen, imgs).sum() + n_loss = noise_regularize(noises) + mse_loss = F.mse_loss(img_gen, imgs) + + loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + noise_normalize_(noises) + + if (i + 1) % 100 == 0: + latent_path.append(latent_in.detach().clone()) + + pbar.set_description( + ( + f'perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};' + f' mse: {mse_loss.item():.4f}; lr: {lr:.4f}' + ) + ) + + result_file = {'noises': noises} + + img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises) + + filename = os.path.splitext(os.path.basename(args.files[0]))[0] + '.pt' + + img_ar = make_image(img_gen) + + for i, input_name in enumerate(args.files): + result_file[input_name] = {'img': img_gen[i], 'latent': latent_in[i]} + img_name = os.path.splitext(os.path.basename(input_name))[0] + '-project.png' + pil_img = Image.fromarray(img_ar[i]) + pil_img.save(img_name) + + torch.save(result_file, filename) diff --git a/models/stylegan2/stylegan2-pytorch/sample/.gitignore b/models/stylegan2/stylegan2-pytorch/sample/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e33609d251c814ccd3a30337c965a875645c2117 --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/sample/.gitignore @@ -0,0 +1 @@ +*.png diff --git a/models/stylegan2/stylegan2-pytorch/train.py b/models/stylegan2/stylegan2-pytorch/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7295f159b0427aef89a5944a0d1eb4c23ee85a7f --- /dev/null +++ b/models/stylegan2/stylegan2-pytorch/train.py @@ -0,0 +1,413 @@ +import argparse +import math +import random +import os + +import numpy as np +import torch +from torch import nn, autograd, optim +from torch.nn import functional as F +from torch.utils import data +import 
torch.distributed as dist +from torchvision import transforms, utils +from tqdm import tqdm + +try: + import wandb + +except ImportError: + wandb = None + +from model import Generator, Discriminator +from dataset import MultiResolutionDataset +from distributed import ( + get_rank, + synchronize, + reduce_loss_dict, + reduce_sum, + get_world_size, +) + + +def data_sampler(dataset, shuffle, distributed): + if distributed: + return data.distributed.DistributedSampler(dataset, shuffle=shuffle) + + if shuffle: + return data.RandomSampler(dataset) + + else: + return data.SequentialSampler(dataset) + + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +def accumulate(model1, model2, decay=0.999): + par1 = dict(model1.named_parameters()) + par2 = dict(model2.named_parameters()) + + for k in par1.keys(): + par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) + + +def sample_data(loader): + while True: + for batch in loader: + yield batch + + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(real_pred, real_img): + grad_real, = autograd.grad( + outputs=real_pred.sum(), inputs=real_img, create_graph=True + ) + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + + return grad_penalty + + +def g_nonsaturating_loss(fake_pred): + loss = F.softplus(-fake_pred).mean() + + return loss + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt( + fake_img.shape[2] * fake_img.shape[3] + ) + grad, = autograd.grad( + outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True + ) + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_mean.detach(), path_lengths + + +def make_noise(batch, latent_dim, n_noise, device): + if n_noise == 1: + return torch.randn(batch, latent_dim, device=device) + + noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) + + return noises + + +def mixing_noise(batch, latent_dim, prob, device): + if prob > 0 and random.random() < prob: + return make_noise(batch, latent_dim, 2, device) + + else: + return [make_noise(batch, latent_dim, 1, device)] + + +def set_grad_none(model, targets): + for n, p in model.named_parameters(): + if n in targets: + p.grad = None + + +def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): + loader = sample_data(loader) + + pbar = range(args.iter) + + if get_rank() == 0: + pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) + + mean_path_length = 0 + + d_loss_val = 0 + r1_loss = torch.tensor(0.0, device=device) + g_loss_val = 0 + path_loss = torch.tensor(0.0, device=device) + path_lengths = torch.tensor(0.0, device=device) + mean_path_length_avg = 0 + loss_dict = {} + + if args.distributed: + g_module = generator.module + d_module = discriminator.module + + else: + g_module = generator + d_module = discriminator + + accum = 0.5 ** (32 / (10 * 1000)) + + sample_z = torch.randn(args.n_sample, args.latent, device=device) + + for idx in pbar: + i = idx + args.start_iter + + if i > args.iter: + print("Done!") + + break + + real_img = next(loader) + real_img = real_img.to(device) + + requires_grad(generator, False) + 
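+        # Descriptive note: the loop alternates optimization steps. First the
+        # discriminator step (generator frozen) with lazy R1 regularization
+        # every args.d_reg_every iterations, then the generator step with lazy
+        # path-length regularization every args.g_reg_every iterations, and
+        # finally an EMA update of g_ema via accumulate().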
requires_grad(discriminator, True) + + noise = mixing_noise(args.batch, args.latent, args.mixing, device) + fake_img, _ = generator(noise) + fake_pred = discriminator(fake_img) + + real_pred = discriminator(real_img) + d_loss = d_logistic_loss(real_pred, fake_pred) + + loss_dict["d"] = d_loss + loss_dict["real_score"] = real_pred.mean() + loss_dict["fake_score"] = fake_pred.mean() + + discriminator.zero_grad() + d_loss.backward() + d_optim.step() + + d_regularize = i % args.d_reg_every == 0 + + if d_regularize: + real_img.requires_grad = True + real_pred = discriminator(real_img) + r1_loss = d_r1_loss(real_pred, real_img) + + discriminator.zero_grad() + (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() + + d_optim.step() + + loss_dict["r1"] = r1_loss + + requires_grad(generator, True) + requires_grad(discriminator, False) + + noise = mixing_noise(args.batch, args.latent, args.mixing, device) + fake_img, _ = generator(noise) + fake_pred = discriminator(fake_img) + g_loss = g_nonsaturating_loss(fake_pred) + + loss_dict["g"] = g_loss + + generator.zero_grad() + g_loss.backward() + g_optim.step() + + g_regularize = i % args.g_reg_every == 0 + + if g_regularize: + path_batch_size = max(1, args.batch // args.path_batch_shrink) + noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) + fake_img, latents = generator(noise, return_latents=True) + + path_loss, mean_path_length, path_lengths = g_path_regularize( + fake_img, latents, mean_path_length + ) + + generator.zero_grad() + weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss + + if args.path_batch_shrink: + weighted_path_loss += 0 * fake_img[0, 0, 0, 0] + + weighted_path_loss.backward() + + g_optim.step() + + mean_path_length_avg = ( + reduce_sum(mean_path_length).item() / get_world_size() + ) + + loss_dict["path"] = path_loss + loss_dict["path_length"] = path_lengths.mean() + + accumulate(g_ema, g_module, accum) + + loss_reduced = reduce_loss_dict(loss_dict) + + d_loss_val = loss_reduced["d"].mean().item() + g_loss_val = loss_reduced["g"].mean().item() + r1_val = loss_reduced["r1"].mean().item() + path_loss_val = loss_reduced["path"].mean().item() + real_score_val = loss_reduced["real_score"].mean().item() + fake_score_val = loss_reduced["fake_score"].mean().item() + path_length_val = loss_reduced["path_length"].mean().item() + + if get_rank() == 0: + pbar.set_description( + ( + f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " + f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}" + ) + ) + + if wandb and args.wandb: + wandb.log( + { + "Generator": g_loss_val, + "Discriminator": d_loss_val, + "R1": r1_val, + "Path Length Regularization": path_loss_val, + "Mean Path Length": mean_path_length, + "Real Score": real_score_val, + "Fake Score": fake_score_val, + "Path Length": path_length_val, + } + ) + + if i % 100 == 0: + with torch.no_grad(): + g_ema.eval() + sample, _ = g_ema([sample_z]) + utils.save_image( + sample, + f"sample/{str(i).zfill(6)}.png", + nrow=int(args.n_sample ** 0.5), + normalize=True, + range=(-1, 1), + ) + + if i % 10000 == 0: + torch.save( + { + "g": g_module.state_dict(), + "d": d_module.state_dict(), + "g_ema": g_ema.state_dict(), + "g_optim": g_optim.state_dict(), + "d_optim": d_optim.state_dict(), + }, + f"checkpoint/{str(i).zfill(6)}.pt", + ) + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser() + + parser.add_argument("path", type=str) + parser.add_argument("--iter", type=int, default=800000) + 
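+    # Illustrative launch commands (LMDB path is an example):
+    #   single GPU:  python train.py --batch 16 --size 256 /path/to/lmdb
+    #   multi GPU:   python -m torch.distributed.launch --nproc_per_node=N_GPU \
+    #                    train.py --batch 16 --size 256 /path/to/lmdb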
parser.add_argument("--batch", type=int, default=16) + parser.add_argument("--n_sample", type=int, default=64) + parser.add_argument("--size", type=int, default=256) + parser.add_argument("--r1", type=float, default=10) + parser.add_argument("--path_regularize", type=float, default=2) + parser.add_argument("--path_batch_shrink", type=int, default=2) + parser.add_argument("--d_reg_every", type=int, default=16) + parser.add_argument("--g_reg_every", type=int, default=4) + parser.add_argument("--mixing", type=float, default=0.9) + parser.add_argument("--ckpt", type=str, default=None) + parser.add_argument("--lr", type=float, default=0.002) + parser.add_argument("--channel_multiplier", type=int, default=2) + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--local_rank", type=int, default=0) + + args = parser.parse_args() + + n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + args.distributed = n_gpu > 1 + + if args.distributed: + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend="nccl", init_method="env://") + synchronize() + + args.latent = 512 + args.n_mlp = 8 + + args.start_iter = 0 + + generator = Generator( + args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier + ).to(device) + discriminator = Discriminator( + args.size, channel_multiplier=args.channel_multiplier + ).to(device) + g_ema = Generator( + args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier + ).to(device) + g_ema.eval() + accumulate(g_ema, generator, 0) + + g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) + d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) + + g_optim = optim.Adam( + generator.parameters(), + lr=args.lr * g_reg_ratio, + betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), + ) + d_optim = optim.Adam( + discriminator.parameters(), + lr=args.lr * d_reg_ratio, + betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), + ) + + if args.ckpt is not None: + print("load model:", args.ckpt) + + ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) + + try: + ckpt_name = os.path.basename(args.ckpt) + args.start_iter = int(os.path.splitext(ckpt_name)[0]) + + except ValueError: + pass + + generator.load_state_dict(ckpt["g"]) + discriminator.load_state_dict(ckpt["d"]) + g_ema.load_state_dict(ckpt["g_ema"]) + + g_optim.load_state_dict(ckpt["g_optim"]) + d_optim.load_state_dict(ckpt["d_optim"]) + + if args.distributed: + generator = nn.parallel.DistributedDataParallel( + generator, + device_ids=[args.local_rank], + output_device=args.local_rank, + broadcast_buffers=False, + ) + + discriminator = nn.parallel.DistributedDataParallel( + discriminator, + device_ids=[args.local_rank], + output_device=args.local_rank, + broadcast_buffers=False, + ) + + transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), + ] + ) + + dataset = MultiResolutionDataset(args.path, transform, args.size) + loader = data.DataLoader( + dataset, + batch_size=args.batch, + sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), + drop_last=True, + ) + + if get_rank() == 0 and wandb is not None and args.wandb: + wandb.init(project="stylegan 2") + + train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device) diff --git a/models/wrappers.py b/models/wrappers.py new file mode 100644 index 
0000000000000000000000000000000000000000..335321bc67e7b3c7f1e715948e967388c3be05f9 --- /dev/null +++ b/models/wrappers.py @@ -0,0 +1,737 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import torch +import numpy as np +import re +import os +import random +from pathlib import Path +from types import SimpleNamespace +from utils import download_ckpt +from config import Config +from netdissect import proggan, zdataset +from . import biggan +from . import stylegan +from . import stylegan2 +from abc import abstractmethod, ABC as AbstractBaseClass +from functools import singledispatch + +class BaseModel(AbstractBaseClass, torch.nn.Module): + + # Set parameters for identifying model from instance + def __init__(self, model_name, class_name): + super(BaseModel, self).__init__() + self.model_name = model_name + self.outclass = class_name + + # Stop model evaluation as soon as possible after + # given layer has been executed, used to speed up + # netdissect.InstrumentedModel::retain_layer(). + # Validate with tests/partial_forward_test.py + # Can use forward() as fallback at the cost of performance. + @abstractmethod + def partial_forward(self, x, layer_name): + pass + + # Generate batch of latent vectors + @abstractmethod + def sample_latent(self, n_samples=1, seed=None, truncation=None): + pass + + # Maximum number of latents that can be provided + # Typically one for each layer + def get_max_latents(self): + return 1 + + # Name of primary latent space + # E.g. 
StyleGAN can alternatively use W + def latent_space_name(self): + return 'Z' + + def get_latent_shape(self): + return tuple(self.sample_latent(1).shape) + + def get_latent_dims(self): + return np.prod(self.get_latent_shape()) + + def set_output_class(self, new_class): + self.outclass = new_class + + # Map from typical range [-1, 1] to [0, 1] + def forward(self, x): + out = self.model.forward(x) + return 0.5*(out+1) + + # Generate images and convert to numpy + def sample_np(self, z=None, n_samples=1, seed=None): + if z is None: + z = self.sample_latent(n_samples, seed=seed) + elif isinstance(z, list): + z = [torch.tensor(l).to(self.device) if not torch.is_tensor(l) else l for l in z] + elif not torch.is_tensor(z): + z = torch.tensor(z).to(self.device) + img = self.forward(z) + img_np = img.permute(0, 2, 3, 1).cpu().detach().numpy() + return np.clip(img_np, 0.0, 1.0).squeeze() + + # For models that use part of latent as conditioning + def get_conditional_state(self, z): + return None + + # For models that use part of latent as conditioning + def set_conditional_state(self, z, c): + return z + + def named_modules(self, *args, **kwargs): + return self.model.named_modules(*args, **kwargs) + +# PyTorch port of StyleGAN 2 +class StyleGAN2(BaseModel): + def __init__(self, device, class_name, truncation=1.0, use_w=False): + super(StyleGAN2, self).__init__('StyleGAN2', class_name or 'ffhq') + self.device = device + self.truncation = truncation + self.latent_avg = None + self.w_primary = use_w # use W as primary latent space? + + # Image widths + configs = { + # Converted NVIDIA official + 'ffhq': 1024, + 'car': 512, + 'cat': 256, + 'church': 256, + 'horse': 256, + # Tuomas + 'bedrooms': 256, + 'kitchen': 256, + 'places': 256, + 'lookbook': 512 + } + + assert self.outclass in configs, \ + f'Invalid StyleGAN2 class {self.outclass}, should be one of [{", ".join(configs.keys())}]' + + self.resolution = configs[self.outclass] + self.name = f'StyleGAN2-{self.outclass}' + self.has_latent_residual = True + self.load_model() + self.set_noise_seed(0) + + def latent_space_name(self): + return 'W' if self.w_primary else 'Z' + + def use_w(self): + self.w_primary = True + + def use_z(self): + self.w_primary = False + + # URLs created with https://sites.google.com/site/gdocs2direct/ + def download_checkpoint(self, outfile): + checkpoints = { + 'horse': 'https://drive.google.com/uc?export=download&id=18SkqWAkgt0fIwDEf2pqeaenNi4OoCo-0', + 'ffhq': 'https://drive.google.com/uc?export=download&id=1FJRwzAkV-XWbxgTwxEmEACvuqF5DsBiV', + 'church': 'https://drive.google.com/uc?export=download&id=1HFM694112b_im01JT7wop0faftw9ty5g', + 'car': 'https://drive.google.com/uc?export=download&id=1iRoWclWVbDBAy5iXYZrQnKYSbZUqXI6y', + 'cat': 'https://drive.google.com/uc?export=download&id=15vJP8GDr0FlRYpE8gD7CdeEz2mXrQMgN', + 'places': 'https://drive.google.com/uc?export=download&id=1X8-wIH3aYKjgDZt4KMOtQzN1m4AlCVhm', + 'bedrooms': 'https://drive.google.com/uc?export=download&id=1nZTW7mjazs-qPhkmbsOLLA_6qws-eNQu', + 'kitchen': 'https://drive.google.com/uc?export=download&id=15dCpnZ1YLAnETAPB0FGmXwdBclbwMEkZ', + 'lookbook': 'https://drive.google.com/uc?export=download&id=1-F-RMkbHUv_S_k-_olh43mu5rDUMGYKe' + } + + url = checkpoints[self.outclass] + download_ckpt(url, outfile) + + def load_model(self): + checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') + checkpoint = Path(checkpoint_root) / f'stylegan2/stylegan2_{self.outclass}_{self.resolution}.pt' + + self.model = 
stylegan2.Generator(self.resolution, 512, 8).to(self.device) + + if not checkpoint.is_file(): + os.makedirs(checkpoint.parent, exist_ok=True) + self.download_checkpoint(checkpoint) + + ckpt = torch.load(checkpoint) + self.model.load_state_dict(ckpt['g_ema'], strict=False) + self.latent_avg = 0 + + def sample_latent(self, n_samples=1, seed=None, truncation=None): + if seed is None: + seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state + + rng = np.random.RandomState(seed) + z = torch.from_numpy( + rng.standard_normal(512 * n_samples) + .reshape(n_samples, 512)).float().to(self.device) #[N, 512] + + if self.w_primary: + z = self.model.style(z) + + return z + + def get_max_latents(self): + return self.model.n_latent + + def set_output_class(self, new_class): + if self.outclass != new_class: + raise RuntimeError('StyleGAN2: cannot change output class without reloading') + + def forward(self, x): + x = x if isinstance(x, list) else [x] + out, _ = self.model(x, noise=self.noise, + truncation=self.truncation, truncation_latent=self.latent_avg, input_is_w=self.w_primary) + return 0.5*(out+1) + + def partial_forward(self, x, layer_name): + styles = x if isinstance(x, list) else [x] + inject_index = None + noise = self.noise + + if not self.w_primary: + styles = [self.model.style(s) for s in styles] + + if len(styles) == 1: + # One global latent + inject_index = self.model.n_latent + latent = self.model.strided_style(styles[0].unsqueeze(1).repeat(1, inject_index, 1)) # [N, 18, 512] + elif len(styles) == 2: + # Latent mixing with two latents + if inject_index is None: + inject_index = random.randint(1, self.model.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.model.n_latent - inject_index, 1) + + latent = self.model.strided_style(torch.cat([latent, latent2], 1)) + else: + # One latent per layer + assert len(styles) == self.model.n_latent, f'Expected {self.model.n_latents} latents, got {len(styles)}' + styles = torch.stack(styles, dim=1) # [N, 18, 512] + latent = self.model.strided_style(styles) + + if 'style' in layer_name: + return + + out = self.model.input(latent) + if 'input' == layer_name: + return + + out = self.model.conv1(out, latent[:, 0], noise=noise[0]) + if 'conv1' in layer_name: + return + + skip = self.model.to_rgb1(out, latent[:, 1]) + if 'to_rgb1' in layer_name: + return + + i = 1 + noise_i = 1 + + for conv1, conv2, to_rgb in zip( + self.model.convs[::2], self.model.convs[1::2], self.model.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise[noise_i]) + if f'convs.{i-1}' in layer_name: + return + + out = conv2(out, latent[:, i + 1], noise=noise[noise_i + 1]) + if f'convs.{i}' in layer_name: + return + + skip = to_rgb(out, latent[:, i + 2], skip) + if f'to_rgbs.{i//2}' in layer_name: + return + + i += 2 + noise_i += 2 + + image = skip + + raise RuntimeError(f'Layer {layer_name} not encountered in partial_forward') + + def set_noise_seed(self, seed): + torch.manual_seed(seed) + self.noise = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=self.device)] + + for i in range(3, self.model.log_size + 1): + for _ in range(2): + self.noise.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=self.device)) + +# PyTorch port of StyleGAN 1 +class StyleGAN(BaseModel): + def __init__(self, device, class_name, truncation=1.0, use_w=False): + super(StyleGAN, self).__init__('StyleGAN', class_name or 'ffhq') + self.device = device + self.w_primary = use_w # is W primary latent space? 
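+        # Descriptive note: supported output classes and their native
+        # resolutions; the weights come either from the converted official
+        # NVIDIA release or from the justinpinkney/awesome-pretrained-stylegan
+        # collection, see load_model() below.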
+ + configs = { + # Official + 'ffhq': 1024, + 'celebahq': 1024, + 'bedrooms': 256, + 'cars': 512, + 'cats': 256, + + # From https://github.com/justinpinkney/awesome-pretrained-stylegan + 'vases': 1024, + 'wikiart': 512, + 'fireworks': 512, + 'abstract': 512, + 'anime': 512, + 'ukiyo-e': 512, + } + + assert self.outclass in configs, \ + f'Invalid StyleGAN class {self.outclass}, should be one of [{", ".join(configs.keys())}]' + + self.resolution = configs[self.outclass] + self.name = f'StyleGAN-{self.outclass}' + self.has_latent_residual = True + self.load_model() + self.set_noise_seed(0) + + def latent_space_name(self): + return 'W' if self.w_primary else 'Z' + + def use_w(self): + self.w_primary = True + + def use_z(self): + self.w_primary = False + + def load_model(self): + checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') + checkpoint = Path(checkpoint_root) / f'stylegan/stylegan_{self.outclass}_{self.resolution}.pt' + + self.model = stylegan.StyleGAN_G(self.resolution).to(self.device) + + urls_tf = { + 'vases': 'https://thisvesseldoesnotexist.s3-us-west-2.amazonaws.com/public/network-snapshot-008980.pkl', + 'fireworks': 'https://mega.nz/#!7uBHnACY!quIW-pjdDa7NqnZOYh1z5UemWwPOW6HkYSoJ4usCg9U', + 'abstract': 'https://mega.nz/#!vCQyHQZT!zdeOg3VvT4922Z2UfxO51xgAfJD-NAK2nW7H_jMlilU', + 'anime': 'https://mega.nz/#!vawjXISI!F7s13yRicxDA3QYqYDL2kjnc2K7Zk3DwCIYETREmBP4', + 'ukiyo-e': 'https://drive.google.com/uc?id=1CHbJlci9NhVFifNQb3vCGu6zw4eqzvTd', + } + + urls_torch = { + 'celebahq': 'https://drive.google.com/uc?export=download&id=1lGcRwNoXy_uwXkD6sy43aAa-rMHRR7Ad', + 'bedrooms': 'https://drive.google.com/uc?export=download&id=1r0_s83-XK2dKlyY3WjNYsfZ5-fnH8QgI', + 'ffhq': 'https://drive.google.com/uc?export=download&id=1GcxTcLDPYxQqcQjeHpLUutGzwOlXXcks', + 'cars': 'https://drive.google.com/uc?export=download&id=1aaUXHRHjQ9ww91x4mtPZD0w50fsIkXWt', + 'cats': 'https://drive.google.com/uc?export=download&id=1JzA5iiS3qPrztVofQAjbb0N4xKdjOOyV', + 'wikiart': 'https://drive.google.com/uc?export=download&id=1fN3noa7Rsl9slrDXsgZVDsYFxV0O08Vx', + } + + if not checkpoint.is_file(): + os.makedirs(checkpoint.parent, exist_ok=True) + if self.outclass in urls_torch: + download_ckpt(urls_torch[self.outclass], checkpoint) + else: + checkpoint_tf = checkpoint.with_suffix('.pkl') + if not checkpoint_tf.is_file(): + download_ckpt(urls_tf[self.outclass], checkpoint_tf) + print('Converting TensorFlow checkpoint to PyTorch') + self.model.export_from_tf(checkpoint_tf) + + self.model.load_weights(checkpoint) + + def sample_latent(self, n_samples=1, seed=None, truncation=None): + if seed is None: + seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state + + rng = np.random.RandomState(seed) + noise = torch.from_numpy( + rng.standard_normal(512 * n_samples) + .reshape(n_samples, 512)).float().to(self.device) #[N, 512] + + if self.w_primary: + noise = self.model._modules['g_mapping'].forward(noise) + + return noise + + def get_max_latents(self): + return 18 + + def set_output_class(self, new_class): + if self.outclass != new_class: + raise RuntimeError('StyleGAN: cannot change output class without reloading') + + def forward(self, x): + out = self.model.forward(x, latent_is_w=self.w_primary) + return 0.5*(out+1) + + # Run model only until given layer + def partial_forward(self, x, layer_name): + mapping = self.model._modules['g_mapping'] + G = self.model._modules['g_synthesis'] + trunc = self.model._modules.get('truncation', lambda x : x) 
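Taken together, the wrapper classes expose the small BaseModel interface (sample_latent, sample_np, use_w/use_z, forward). A usage sketch, assuming the wrapper module is importable and the corresponding checkpoint can be downloaded on first use:

import torch
from models.wrappers import StyleGAN2   # import path is an assumption

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = StyleGAN2(device, class_name='ffhq', truncation=0.7)   # downloads checkpoint if missing

print(model.latent_space_name())               # 'Z' by default
z = model.sample_latent(n_samples=2, seed=0)   # reproducible [2, 512] batch
imgs = model.sample_np(z)                      # numpy images in [0, 1]

model.use_w()                                  # switch primary latent space to W
w = model.sample_latent(n_samples=2, seed=0)   # now returns mapped W codes
print(model.latent_space_name())               # 'W'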
+ + if not self.w_primary: + x = mapping.forward(x) # handles list inputs + + if isinstance(x, list): + x = torch.stack(x, dim=1) + else: + x = x.unsqueeze(1).expand(-1, 18, -1) + + # Whole mapping + if 'g_mapping' in layer_name: + return + + x = trunc(x) + if layer_name == 'truncation': + return + + # Get names of children + def iterate(m, name, seen): + children = getattr(m, '_modules', []) + if len(children) > 0: + for child_name, module in children.items(): + seen += iterate(module, f'{name}.{child_name}', seen) + return seen + else: + return [name] + + # Generator + batch_size = x.size(0) + for i, (n, m) in enumerate(G.blocks.items()): # InputBlock or GSynthesisBlock + if i == 0: + r = m(x[:, 2*i:2*i+2]) + else: + r = m(r, x[:, 2*i:2*i+2]) + + children = iterate(m, f'g_synthesis.blocks.{n}', []) + for c in children: + if layer_name in c: # substring + return + + raise RuntimeError(f'Layer {layer_name} not encountered in partial_forward') + + + def set_noise_seed(self, seed): + G = self.model._modules['g_synthesis'] + + def for_each_child(this, name, func): + children = getattr(this, '_modules', []) + for child_name, module in children.items(): + for_each_child(module, f'{name}.{child_name}', func) + func(this, name) + + def modify(m, name): + if isinstance(m, stylegan.NoiseLayer): + H, W = [int(s) for s in name.split('.')[2].split('x')] + torch.random.manual_seed(seed) + m.noise = torch.randn(1, 1, H, W, device=self.device, dtype=torch.float32) + #m.noise = 1.0 # should be [N, 1, H, W], but this also works + + for_each_child(G, 'g_synthesis', modify) + +class GANZooModel(BaseModel): + def __init__(self, device, model_name): + super(GANZooModel, self).__init__(model_name, 'default') + self.device = device + self.base_model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', + model_name, pretrained=True, useGPU=(device.type == 'cuda')) + self.model = self.base_model.netG.to(self.device) + self.name = model_name + self.has_latent_residual = False + + def sample_latent(self, n_samples=1, seed=0, truncation=None): + # Uses torch.randn + noise, _ = self.base_model.buildNoiseData(n_samples) + return noise + + # Don't bother for now + def partial_forward(self, x, layer_name): + return self.forward(x) + + def get_conditional_state(self, z): + return z[:, -20:] # last 20 = conditioning + + def set_conditional_state(self, z, c): + z[:, -20:] = c + return z + + def forward(self, x): + out = self.base_model.test(x) + return 0.5*(out+1) + + +class ProGAN(BaseModel): + def __init__(self, device, lsun_class=None): + super(ProGAN, self).__init__('ProGAN', lsun_class) + self.device = device + + # These are downloaded by GANDissect + valid_classes = [ 'bedroom', 'churchoutdoor', 'conferenceroom', 'diningroom', 'kitchen', 'livingroom', 'restaurant' ] + assert self.outclass in valid_classes, \ + f'Invalid LSUN class {self.outclass}, should be one of {valid_classes}' + + self.load_model() + self.name = f'ProGAN-{self.outclass}' + self.has_latent_residual = False + + def load_model(self): + checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') + checkpoint = Path(checkpoint_root) / f'progan/{self.outclass}_lsun.pth' + + if not checkpoint.is_file(): + os.makedirs(checkpoint.parent, exist_ok=True) + url = f'http://netdissect.csail.mit.edu/data/ganmodel/karras/{self.outclass}_lsun.pth' + download_ckpt(url, checkpoint) + + self.model = proggan.from_pth_file(str(checkpoint.resolve())).to(self.device) + + def sample_latent(self, n_samples=1, seed=None, 
truncation=None): + if seed is None: + seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state + noise = zdataset.z_sample_for_model(self.model, n_samples, seed=seed)[...] + return noise.to(self.device) + + def forward(self, x): + if isinstance(x, list): + assert len(x) == 1, "ProGAN only supports a single global latent" + x = x[0] + + out = self.model.forward(x) + return 0.5*(out+1) + + # Run model only until given layer + def partial_forward(self, x, layer_name): + assert isinstance(self.model, torch.nn.Sequential), 'Expected sequential model' + + if isinstance(x, list): + assert len(x) == 1, "ProGAN only supports a single global latent" + x = x[0] + + x = x.view(x.shape[0], x.shape[1], 1, 1) + for name, module in self.model._modules.items(): # ordered dict + x = module(x) + if name == layer_name: + return + + raise RuntimeError(f'Layer {layer_name} not encountered in partial_forward') + + +class BigGAN(BaseModel): + def __init__(self, device, resolution, class_name, truncation=1.0): + super(BigGAN, self).__init__(f'BigGAN-{resolution}', class_name) + self.device = device + self.truncation = truncation + self.load_model(f'biggan-deep-{resolution}') + self.set_output_class(class_name or 'husky') + self.name = f'BigGAN-{resolution}-{self.outclass}-t{self.truncation}' + self.has_latent_residual = True + + # Default implementaiton fails without an internet + # connection, even if the model has been cached + def load_model(self, name): + if name not in biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP: + raise RuntimeError('Unknown BigGAN model name', name) + + checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') + model_path = Path(checkpoint_root) / name + + os.makedirs(model_path, exist_ok=True) + + model_file = model_path / biggan.model.WEIGHTS_NAME + config_file = model_path / biggan.model.CONFIG_NAME + model_url = biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP[name] + config_url = biggan.model.PRETRAINED_CONFIG_ARCHIVE_MAP[name] + + for filename, url in ((model_file, model_url), (config_file, config_url)): + if not filename.is_file(): + print('Downloading', url) + with open(filename, 'wb') as f: + if url.startswith("s3://"): + biggan.s3_get(url, f) + else: + biggan.http_get(url, f) + + self.model = biggan.BigGAN.from_pretrained(model_path).to(self.device) + + def sample_latent(self, n_samples=1, truncation=None, seed=None): + if seed is None: + seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state + + noise_vector = biggan.truncated_noise_sample(truncation=truncation or self.truncation, batch_size=n_samples, seed=seed) + noise = torch.from_numpy(noise_vector) #[N, 128] + + return noise.to(self.device) + + # One extra for gen_z + def get_max_latents(self): + return len(self.model.config.layers) + 1 + + def get_conditional_state(self, z): + return self.v_class + + def set_conditional_state(self, z, c): + self.v_class = c + + def is_valid_class(self, class_id): + if isinstance(class_id, int): + return class_id < 1000 + elif isinstance(class_id, str): + return biggan.one_hot_from_names([class_id.replace(' ', '_')]) is not None + else: + raise RuntimeError(f'Unknown class identifier {class_id}') + + def set_output_class(self, class_id): + if isinstance(class_id, int): + self.v_class = torch.from_numpy(biggan.one_hot_from_int([class_id])).to(self.device) + self.outclass = f'class{class_id}' + elif isinstance(class_id, str): + self.outclass = class_id.replace(' ', '_') + self.v_class = 
torch.from_numpy(biggan.one_hot_from_names([class_id])).to(self.device) + else: + raise RuntimeError(f'Unknown class identifier {class_id}') + + def forward(self, x): + # Duplicate along batch dimension + if isinstance(x, list): + c = self.v_class.repeat(x[0].shape[0], 1) + class_vector = len(x)*[c] + else: + class_vector = self.v_class.repeat(x.shape[0], 1) + out = self.model.forward(x, class_vector, self.truncation) # [N, 3, 128, 128], in [-1, 1] + return 0.5*(out+1) + + # Run model only until given layer + # Used to speed up PCA sample collection + def partial_forward(self, x, layer_name): + if layer_name in ['embeddings', 'generator.gen_z']: + n_layers = 0 + elif 'generator.layers' in layer_name: + layer_base = re.match('^generator\.layers\.[0-9]+', layer_name)[0] + n_layers = int(layer_base.split('.')[-1]) + 1 + else: + n_layers = len(self.model.config.layers) + + if not isinstance(x, list): + x = self.model.n_latents*[x] + + if isinstance(self.v_class, list): + labels = [c.repeat(x[0].shape[0], 1) for c in class_label] + embed = [self.model.embeddings(l) for l in labels] + else: + class_label = self.v_class.repeat(x[0].shape[0], 1) + embed = len(x)*[self.model.embeddings(class_label)] + + assert len(x) == self.model.n_latents, f'Expected {self.model.n_latents} latents, got {len(x)}' + assert len(embed) == self.model.n_latents, f'Expected {self.model.n_latents} class vectors, got {len(class_label)}' + + cond_vectors = [torch.cat((z, e), dim=1) for (z, e) in zip(x, embed)] + + # Generator forward + z = self.model.generator.gen_z(cond_vectors[0]) + z = z.view(-1, 4, 4, 16 * self.model.generator.config.channel_width) + z = z.permute(0, 3, 1, 2).contiguous() + + cond_idx = 1 + for i, layer in enumerate(self.model.generator.layers[:n_layers]): + if isinstance(layer, biggan.GenBlock): + z = layer(z, cond_vectors[cond_idx], self.truncation) + cond_idx += 1 + else: + z = layer(z) + + return None + +# Version 1: separate parameters +@singledispatch +def get_model(name, output_class, device, **kwargs): + # Check if optionally provided existing model can be reused + inst = kwargs.get('inst', None) + model = kwargs.get('model', None) + + if inst or model: + cached = model or inst.model + + network_same = (cached.model_name == name) + outclass_same = (cached.outclass == output_class) + can_change_class = ('BigGAN' in name) + + if network_same and (outclass_same or can_change_class): + cached.set_output_class(output_class) + return cached + + if name == 'DCGAN': + import warnings + warnings.filterwarnings("ignore", message="nn.functional.tanh is deprecated") + model = GANZooModel(device, 'DCGAN') + elif name == 'ProGAN': + model = ProGAN(device, output_class) + elif 'BigGAN' in name: + assert '-' in name, 'Please specify BigGAN resolution, e.g. 
BigGAN-512' + model = BigGAN(device, name.split('-')[-1], class_name=output_class) + elif name == 'StyleGAN': + model = StyleGAN(device, class_name=output_class) + elif name == 'StyleGAN2': + model = StyleGAN2(device, class_name=output_class) + else: + raise RuntimeError(f'Unknown model {name}') + + return model + +# Version 2: Config object +@get_model.register(Config) +def _(cfg, device, **kwargs): + kwargs['use_w'] = kwargs.get('use_w', cfg.use_w) # explicit arg can override cfg + return get_model(cfg.model, cfg.output_class, device, **kwargs) + +# Version 1: separate parameters +@singledispatch +def get_instrumented_model(name, output_class, layers, device, **kwargs): + model = get_model(name, output_class, device, **kwargs) + model.eval() + + inst = kwargs.get('inst', None) + if inst: + inst.close() + + if not isinstance(layers, list): + layers = [layers] + + # Verify given layer names + module_names = [name for (name, _) in model.named_modules()] + for layer_name in layers: + if not layer_name in module_names: + print(f"Layer '{layer_name}' not found in model!") + print("Available layers:", '\n'.join(module_names)) + raise RuntimeError(f"Unknown layer '{layer_name}''") + + # Reset StyleGANs to z mode for shape annotation + if hasattr(model, 'use_z'): + model.use_z() + + from netdissect.modelconfig import create_instrumented_model + inst = create_instrumented_model(SimpleNamespace( + model = model, + layers = layers, + cuda = device.type == 'cuda', + gen = True, + latent_shape = model.get_latent_shape() + )) + + if kwargs.get('use_w', False): + model.use_w() + + return inst + +# Version 2: Config object +@get_instrumented_model.register(Config) +def _(cfg, device, **kwargs): + kwargs['use_w'] = kwargs.get('use_w', cfg.use_w) # explicit arg can override cfg + return get_instrumented_model(cfg.model, cfg.output_class, cfg.layer, device, **kwargs) diff --git a/netdissect/LICENSE.txt b/netdissect/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..f6b098bc5c232e162b20404c33d5335799a6cc59 --- /dev/null +++ b/netdissect/LICENSE.txt @@ -0,0 +1,19 @@ +Copyright (c) 2020 Erik Härkönen. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/netdissect/__init__.py b/netdissect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39f0957560ff29b9ff0ee630e78972cd3ef187fb --- /dev/null +++ b/netdissect/__init__.py @@ -0,0 +1,60 @@ +''' +Netdissect package. + +To run dissection: + +1. 
Load up the convolutional model you wish to dissect, and wrap it + in an InstrumentedModel. Call imodel.retain_layers([layernames,..]) + to analyze a specified set of layers. +2. Load the segmentation dataset using the BrodenDataset class; + use the transform_image argument to normalize images to be + suitable for the model, or the size argument to truncate the dataset. +3. Write a function to recover the original image (with RGB scaled to + [0...1]) given a normalized dataset image; ReverseNormalize in this + package inverts transforms.Normalize for this purpose. +4. Choose a directory in which to write the output, and call + dissect(outdir, model, dataset). + +Example: + + from netdissect import InstrumentedModel, dissect + from netdissect import BrodenDataset, ReverseNormalize + + model = InstrumentedModel(load_my_model()) + model.eval() + model.cuda() + model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5']) + bds = BrodenDataset('dataset/broden1_227', + transform_image=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), + size=1000) + dissect('result/dissect', model, bds, + recover_image=ReverseNormalize(IMAGE_MEAN, IMAGE_STDEV), + examples_per_unit=10) +''' + +from .dissection import dissect, ReverseNormalize +from .dissection import ClassifierSegRunner, GeneratorSegRunner +from .dissection import ImageOnlySegRunner +from .broden import BrodenDataset, ScaleSegmentation, scatter_batch +from .segdata import MultiSegmentDataset +from .nethook import InstrumentedModel +from .zdataset import z_dataset_for_model, z_sample_for_model, standard_z_sample +from . import actviz +from . import progress +from . import runningstats +from . import sampler + +__all__ = [ + 'dissect', 'ReverseNormalize', + 'ClassifierSegRunner', 'GeneratorSegRunner', 'ImageOnlySegRunner', + 'BrodenDataset', 'ScaleSegmentation', 'scatter_batch', + 'MultiSegmentDataset', + 'InstrumentedModel', + 'z_dataset_for_model', 'z_sample_for_model', 'standard_z_sample' + 'actviz', + 'progress', + 'runningstats', + 'sampler' +] diff --git a/netdissect/__main__.py b/netdissect/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2bd9f630eaa0f45a6a201adcf356a1e092050cb --- /dev/null +++ b/netdissect/__main__.py @@ -0,0 +1,408 @@ +import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL +from torchvision import transforms +from torch.utils.data import TensorDataset +from netdissect.progress import verbose_progress, print_progress +from netdissect import InstrumentedModel, BrodenDataset, dissect +from netdissect import MultiSegmentDataset, GeneratorSegRunner +from netdissect import ImageOnlySegRunner +from netdissect.parallelfolder import ParallelImageFolders +from netdissect.zdataset import z_dataset_for_model +from netdissect.autoeval import autoimport_eval +from netdissect.modelconfig import create_instrumented_model +from netdissect.pidfile import exit_if_job_done, mark_job_done + +help_epilog = '''\ +Example: to dissect three layers of the pretrained alexnet in torchvision: + +python -m netdissect \\ + --model "torchvision.models.alexnet(pretrained=True)" \\ + --layers features.6:conv3 features.8:conv4 features.10:conv5 \\ + --imgsize 227 \\ + --outdir dissect/alexnet-imagenet + +To dissect a progressive GAN model: + +python -m netdissect \\ + --model "proggan.from_pth_file('model/churchoutdoor.pth')" \\ + --gan +''' + +def main(): + # Training settings + def strpair(arg): + p = tuple(arg.split(':')) + if len(p) == 1: + p = p + p + 
return p + def intpair(arg): + p = arg.split(',') + if len(p) == 1: + p = p + p + return tuple(int(v) for v in p) + + parser = argparse.ArgumentParser(description='Net dissect utility', + prog='python -m netdissect', + epilog=textwrap.dedent(help_epilog), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument('--model', type=str, default=None, + help='constructor for the model to test') + parser.add_argument('--pthfile', type=str, default=None, + help='filename of .pth file for the model') + parser.add_argument('--unstrict', action='store_true', default=False, + help='ignore unexpected pth parameters') + parser.add_argument('--submodule', type=str, default=None, + help='submodule to load from pthfile') + parser.add_argument('--outdir', type=str, default='dissect', + help='directory for dissection output') + parser.add_argument('--layers', type=strpair, nargs='+', + help='space-separated list of layer names to dissect' + + ', in the form layername[:reportedname]') + parser.add_argument('--segments', type=str, default='dataset/broden', + help='directory containing segmentation dataset') + parser.add_argument('--segmenter', type=str, default=None, + help='constructor for asegmenter class') + parser.add_argument('--download', action='store_true', default=False, + help='downloads Broden dataset if needed') + parser.add_argument('--imagedir', type=str, default=None, + help='directory containing image-only dataset') + parser.add_argument('--imgsize', type=intpair, default=(227, 227), + help='input image size to use') + parser.add_argument('--netname', type=str, default=None, + help='name for network in generated reports') + parser.add_argument('--meta', type=str, nargs='+', + help='json files of metadata to add to report') + parser.add_argument('--merge', type=str, + help='json file of unit data to merge in report') + parser.add_argument('--examples', type=int, default=20, + help='number of image examples per unit') + parser.add_argument('--size', type=int, default=10000, + help='dataset subset size to use') + parser.add_argument('--batch_size', type=int, default=100, + help='batch size for forward pass') + parser.add_argument('--num_workers', type=int, default=24, + help='number of DataLoader workers') + parser.add_argument('--quantile_threshold', type=strfloat, default=None, + choices=[FloatRange(0.0, 1.0), 'iqr'], + help='quantile to use for masks') + parser.add_argument('--no-labels', action='store_true', default=False, + help='disables labeling of units') + parser.add_argument('--maxiou', action='store_true', default=False, + help='enables maxiou calculation') + parser.add_argument('--covariance', action='store_true', default=False, + help='enables covariance calculation') + parser.add_argument('--rank_all_labels', action='store_true', default=False, + help='include low-information labels in rankings') + parser.add_argument('--no-images', action='store_true', default=False, + help='disables generation of unit images') + parser.add_argument('--no-report', action='store_true', default=False, + help='disables generation report summary') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA usage') + parser.add_argument('--gen', action='store_true', default=False, + help='test a generator model (e.g., a GAN)') + parser.add_argument('--gan', action='store_true', default=False, + help='synonym for --gen') + parser.add_argument('--perturbation', default=None, + help='filename of perturbation attack to apply') + 
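The strpair and intpair helpers above implement the layername[:reportedname] and W[,H] argument conventions mentioned in the option help strings. A standalone restatement with the expected parses spelled out:

def strpair(arg):
    # 'features.6:conv3' -> ('features.6', 'conv3'); 'conv3' -> ('conv3', 'conv3')
    p = tuple(arg.split(':'))
    if len(p) == 1:
        p = p + p
    return p

def intpair(arg):
    # '227' -> (227, 227); '227,300' -> (227, 300)
    p = arg.split(',')
    if len(p) == 1:
        p = p + p
    return tuple(int(v) for v in p)

assert strpair('features.6:conv3') == ('features.6', 'conv3')
assert strpair('conv5') == ('conv5', 'conv5')
assert intpair('227') == (227, 227)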
parser.add_argument('--add_scale_offset', action='store_true', default=None, + help='offsets masks according to stride and padding') + parser.add_argument('--quiet', action='store_true', default=False, + help='silences console output') + if len(sys.argv) == 1: + parser.print_usage(sys.stderr) + sys.exit(1) + args = parser.parse_args() + args.images = not args.no_images + args.report = not args.no_report + args.labels = not args.no_labels + if args.gan: + args.gen = args.gan + + # Set up console output + verbose_progress(not args.quiet) + + # Exit right away if job is already done or being done. + if args.outdir is not None: + exit_if_job_done(args.outdir) + + # Speed up pytorch + torch.backends.cudnn.benchmark = True + + # Special case: download flag without model to test. + if args.model is None and args.download: + from netdissect.broden import ensure_broden_downloaded + for resolution in [224, 227, 384]: + ensure_broden_downloaded(args.segments, resolution, 1) + from netdissect.segmenter import ensure_upp_segmenter_downloaded + ensure_upp_segmenter_downloaded('dataset/segmodel') + sys.exit(0) + + # Help if broden is not present + if not args.gen and not args.imagedir and not os.path.isdir(args.segments): + print_progress('Segmentation dataset not found at %s.' % args.segments) + print_progress('Specify dataset directory using --segments [DIR]') + print_progress('To download Broden, run: netdissect --download') + sys.exit(1) + + # Default segmenter class + if args.gen and args.segmenter is None: + args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" + + "segsizes=[256], segdiv='quad')") + + # Default threshold + if args.quantile_threshold is None: + if args.gen: + args.quantile_threshold = 'iqr' + else: + args.quantile_threshold = 0.005 + + # Set up CUDA + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + torch.backends.cudnn.benchmark = True + + # Construct the network with specified layers instrumented + if args.model is None: + print_progress('No model specified') + sys.exit(1) + model = create_instrumented_model(args) + + # Update any metadata from files, if any + meta = getattr(model, 'meta', {}) + if args.meta: + for mfilename in args.meta: + with open(mfilename) as f: + meta.update(json.load(f)) + + # Load any merge data from files + mergedata = None + if args.merge: + with open(args.merge) as f: + mergedata = json.load(f) + + # Set up the output directory, verify write access + if args.outdir is None: + args.outdir = os.path.join('dissect', type(model).__name__) + exit_if_job_done(args.outdir) + print_progress('Writing output into %s.' % args.outdir) + os.makedirs(args.outdir, exist_ok=True) + train_dataset = None + + if not args.gen: + # Load dataset for classifier case. 
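Besides the CLI path through create_instrumented_model(args), the same helper is called programmatically elsewhere in this patch with a plain namespace of settings. A hedged sketch of that calling pattern; the wrapper class, output class, and layer name below are placeholders, and instantiating a wrapper downloads its checkpoint on first use:

from types import SimpleNamespace
import torch
from netdissect.modelconfig import create_instrumented_model
from models.wrappers import StyleGAN2      # import path is an assumption

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = StyleGAN2(device, class_name='church')   # any wrapper from this patch

inst = create_instrumented_model(SimpleNamespace(
    model=model,
    layers=['convs.0'],                  # placeholder layer name to instrument
    cuda=(device.type == 'cuda'),
    gen=True,                            # generator, not classifier
    latent_shape=model.get_latent_shape(),
))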
+ # Load perturbation + perturbation = numpy.load(args.perturbation + ) if args.perturbation else None + segrunner = None + + # Load broden dataset + if args.imagedir is not None: + dataset = try_to_load_images(args.imagedir, args.imgsize, + perturbation, args.size) + segrunner = ImageOnlySegRunner(dataset) + else: + dataset = try_to_load_broden(args.segments, args.imgsize, 1, + perturbation, args.download, args.size) + if dataset is None: + dataset = try_to_load_multiseg(args.segments, args.imgsize, + perturbation, args.size) + if dataset is None: + print_progress('No segmentation dataset found in %s', + args.segments) + print_progress('use --download to download Broden.') + sys.exit(1) + else: + # For segmenter case the dataset is just a random z + dataset = z_dataset_for_model(model, args.size) + train_dataset = z_dataset_for_model(model, args.size, seed=2) + segrunner = GeneratorSegRunner(autoimport_eval(args.segmenter)) + + # Run dissect + dissect(args.outdir, model, dataset, + train_dataset=train_dataset, + segrunner=segrunner, + examples_per_unit=args.examples, + netname=args.netname, + quantile_threshold=args.quantile_threshold, + meta=meta, + merge=mergedata, + make_images=args.images, + make_labels=args.labels, + make_maxiou=args.maxiou, + make_covariance=args.covariance, + make_report=args.report, + make_row_images=args.images, + make_single_images=True, + rank_all_labels=args.rank_all_labels, + batch_size=args.batch_size, + num_workers=args.num_workers, + settings=vars(args)) + + # Mark the directory so that it's not done again. + mark_job_done(args.outdir) + +class AddPerturbation(object): + def __init__(self, perturbation): + self.perturbation = perturbation + + def __call__(self, pic): + if self.perturbation is None: + return pic + # Convert to a numpy float32 array + npyimg = numpy.array(pic, numpy.uint8, copy=False + ).astype(numpy.float32) + # Center the perturbation + oy, ox = ((self.perturbation.shape[d] - npyimg.shape[d]) // 2 + for d in [0, 1]) + npyimg += self.perturbation[ + oy:oy+npyimg.shape[0], ox:ox+npyimg.shape[1]] + # Pytorch conventions: as a float it should be [0..1] + npyimg.clip(0, 255, npyimg) + return npyimg / 255.0 + +def test_dissection(): + verbose_progress(True) + from torchvision.models import alexnet + from torchvision import transforms + model = InstrumentedModel(alexnet(pretrained=True)) + model.eval() + # Load an alexnet + model.retain_layers([ + ('features.0', 'conv1'), + ('features.3', 'conv2'), + ('features.6', 'conv3'), + ('features.8', 'conv4'), + ('features.10', 'conv5') ]) + # load broden dataset + bds = BrodenDataset('dataset/broden', + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), + size=100) + # run dissect + dissect('dissect/test', model, bds, + examples_per_unit=10) + +def try_to_load_images(directory, imgsize, perturbation, size): + # Load plain image dataset + # TODO: allow other normalizations. 
+ return ParallelImageFolders( + [directory], + transform=transforms.Compose([ + transforms.Resize(imgsize), + AddPerturbation(perturbation), + transforms.ToTensor(), + transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), + size=size) + +def try_to_load_broden(directory, imgsize, broden_version, perturbation, + download, size): + # Load broden dataset + ds_resolution = (224 if max(imgsize) <= 224 else + 227 if max(imgsize) <= 227 else 384) + if not os.path.isfile(os.path.join(directory, + 'broden%d_%d' % (broden_version, ds_resolution), 'index.csv')): + return None + return BrodenDataset(directory, + resolution=ds_resolution, + download=download, + broden_version=broden_version, + transform=transforms.Compose([ + transforms.Resize(imgsize), + AddPerturbation(perturbation), + transforms.ToTensor(), + transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), + size=size) + +def try_to_load_multiseg(directory, imgsize, perturbation, size): + if not os.path.isfile(os.path.join(directory, 'labelnames.json')): + return None + minsize = min(imgsize) if hasattr(imgsize, '__iter__') else imgsize + return MultiSegmentDataset(directory, + transform=(transforms.Compose([ + transforms.Resize(minsize), + transforms.CenterCrop(imgsize), + AddPerturbation(perturbation), + transforms.ToTensor(), + transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), + transforms.Compose([ + transforms.Resize(minsize, interpolation=PIL.Image.NEAREST), + transforms.CenterCrop(imgsize)])), + size=size) + +def add_scale_offset_info(model, layer_names): + ''' + Creates a 'scale_offset' property on the model which guesses + how to offset the featuremap, in cases where the convolutional + padding does not exacly correspond to keeping featuremap pixels + centered on the downsampled regions of the input. This mainly + shows up in AlexNet: ResNet and VGG pad convolutions to keep + them centered and do not need this. 
+ ''' + model.scale_offset = {} + seen = set() + sequence = [] + aka_map = {} + for name in layer_names: + aka = name + if not isinstance(aka, str): + name, aka = name + aka_map[name] = aka + for name, layer in model.named_modules(): + sequence.append(layer) + if name in aka_map: + seen.add(name) + aka = aka_map[name] + model.scale_offset[aka] = sequence_scale_offset(sequence) + for name in aka_map: + assert name in seen, ('Layer %s not found' % name) + +def dilation_scale_offset(dilations): + '''Composes a list of (k, s, p) into a single total scale and offset.''' + if len(dilations) == 0: + return (1, 0) + scale, offset = dilation_scale_offset(dilations[1:]) + kernel, stride, padding = dilations[0] + scale *= stride + offset *= stride + offset += (kernel - 1) / 2.0 - padding + return scale, offset + +def dilations(modulelist): + '''Converts a list of modules to (kernel_size, stride, padding)''' + result = [] + for module in modulelist: + settings = tuple(getattr(module, n, d) + for n, d in (('kernel_size', 1), ('stride', 1), ('padding', 0))) + settings = (((s, s) if not isinstance(s, tuple) else s) + for s in settings) + if settings != ((1, 1), (1, 1), (0, 0)): + result.append(zip(*settings)) + return zip(*result) + +def sequence_scale_offset(modulelist): + '''Returns (yscale, yoffset), (xscale, xoffset) given a list of modules''' + return tuple(dilation_scale_offset(d) for d in dilations(modulelist)) + + +def strfloat(s): + try: + return float(s) + except: + return s + +class FloatRange(object): + def __init__(self, start, end): + self.start = start + self.end = end + def __eq__(self, other): + return isinstance(other, float) and self.start <= other <= self.end + def __repr__(self): + return '[%g-%g]' % (self.start, self.end) + +# Many models use this normalization. +IMAGE_MEAN = [0.485, 0.456, 0.406] +IMAGE_STDEV = [0.229, 0.224, 0.225] + +if __name__ == '__main__': + main() diff --git a/netdissect/aceoptimize.py b/netdissect/aceoptimize.py new file mode 100644 index 0000000000000000000000000000000000000000..46ac0620073a0c26e9ead14b20db57c586ce15aa --- /dev/null +++ b/netdissect/aceoptimize.py @@ -0,0 +1,934 @@ +# Instantiate the segmenter gadget. +# Instantiate the GAN to optimize over +# Instrument the GAN for editing and optimization. +# Read quantile stats to learn 99.9th percentile for each unit, +# and also the 0.01th percentile. +# Read the median activation conditioned on door presence. 
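dilation_scale_offset above composes per-layer (kernel, stride, padding) settings for one spatial dimension into a single scale and offset that maps featuremap coordinates back to input pixels. A worked example of the same recursion, using the first two AlexNet stages (conv: kernel 11, stride 4, pad 2; maxpool: kernel 3, stride 2, pad 0) as illustrative values:

def compose_scale_offset(settings):
    # settings: list of (kernel, stride, padding), outermost layer first.
    if not settings:
        return 1, 0.0
    scale, offset = compose_scale_offset(settings[1:])
    kernel, stride, padding = settings[0]
    scale *= stride
    offset = offset * stride + (kernel - 1) / 2.0 - padding
    return scale, offset

print(compose_scale_offset([(11, 4, 2), (3, 2, 0)]))   # (8, 7.0)
# i.e. featuremap pixel j is centred on input pixel 8*j + 7.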
+ +import os, sys, numpy, torch, argparse, skimage, json, shutil +from PIL import Image +from torch.utils.data import TensorDataset +from matplotlib.figure import Figure +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +import matplotlib.gridspec as gridspec +from scipy.ndimage.morphology import binary_dilation + +import netdissect.zdataset +import netdissect.nethook +from netdissect.dissection import safe_dir_name +from netdissect.progress import verbose_progress, default_progress +from netdissect.progress import print_progress, desc_progress, post_progress +from netdissect.easydict import EasyDict +from netdissect.workerpool import WorkerPool, WorkerBase +from netdissect.runningstats import RunningQuantile +from netdissect.pidfile import pidfile_taken +from netdissect.modelconfig import create_instrumented_model +from netdissect.autoeval import autoimport_eval + +def main(): + parser = argparse.ArgumentParser(description='ACE optimization utility', + prog='python -m netdissect.aceoptimize') + parser.add_argument('--model', type=str, default=None, + help='constructor for the model to test') + parser.add_argument('--pthfile', type=str, default=None, + help='filename of .pth file for the model') + parser.add_argument('--segmenter', type=str, default=None, + help='constructor for asegmenter class') + parser.add_argument('--classname', type=str, default=None, + help='intervention classname') + parser.add_argument('--layer', type=str, default='layer4', + help='layer name') + parser.add_argument('--search_size', type=int, default=10000, + help='size of search for finding training locations') + parser.add_argument('--train_size', type=int, default=1000, + help='size of training set') + parser.add_argument('--eval_size', type=int, default=200, + help='size of eval set') + parser.add_argument('--inference_batch_size', type=int, default=10, + help='forward pass batch size') + parser.add_argument('--train_batch_size', type=int, default=2, + help='backprop pass batch size') + parser.add_argument('--train_update_freq', type=int, default=10, + help='number of batches for each training update') + parser.add_argument('--train_epochs', type=int, default=10, + help='number of epochs of training') + parser.add_argument('--l2_lambda', type=float, default=0.005, + help='l2 regularizer hyperparameter') + parser.add_argument('--eval_only', action='store_true', default=False, + help='reruns eval only on trained snapshots') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA usage') + parser.add_argument('--no-cache', action='store_true', default=False, + help='disables reading of cache') + parser.add_argument('--outdir', type=str, default=None, + help='dissection directory') + parser.add_argument('--variant', type=str, default=None, + help='experiment variant') + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + torch.backends.cudnn.benchmark = True + + run_command(args) + +def run_command(args): + verbose_progress(True) + progress = default_progress() + classname = args.classname # 'door' + layer = args.layer # 'layer4' + num_eval_units = 20 + + assert os.path.isfile(os.path.join(args.outdir, 'dissect.json')), ( + "Should be a dissection directory") + + if args.variant is None: + args.variant = 'ace' + + if args.l2_lambda != 0.005: + args.variant = '%s_reg%g' % (args.variant, args.l2_lambda) + + cachedir = os.path.join(args.outdir, safe_dir_name(layer), args.variant, + classname) + + if 
pidfile_taken(os.path.join(cachedir, 'lock.pid'), True): + sys.exit(0) + + # Take defaults for model constructor etc from dissect.json settings. + with open(os.path.join(args.outdir, 'dissect.json')) as f: + dissection = EasyDict(json.load(f)) + if args.model is None: + args.model = dissection.settings.model + if args.pthfile is None: + args.pthfile = dissection.settings.pthfile + if args.segmenter is None: + args.segmenter = dissection.settings.segmenter + # Default segmenter class + if args.segmenter is None: + args.segmenter = ("netdissect.segmenter.UnifiedParsingSegmenter(" + + "segsizes=[256], segdiv='quad')") + + if (not args.no_cache and + os.path.isfile(os.path.join(cachedir, 'snapshots', 'epoch-%d.npy' % ( + args.train_epochs - 1))) and + os.path.isfile(os.path.join(cachedir, 'report.json'))): + print('%s already done' % cachedir) + sys.exit(0) + + os.makedirs(cachedir, exist_ok=True) + + # Instantiate generator + model = create_instrumented_model(args, gen=True, edit=True, + layers=[args.layer]) + if model is None: + print('No model specified') + sys.exit(1) + # Instantiate segmenter + segmenter = autoimport_eval(args.segmenter) + labelnames, catname = segmenter.get_label_and_category_names() + classnum = [i for i, (n, c) in enumerate(labelnames) if n == classname][0] + num_classes = len(labelnames) + with open(os.path.join(cachedir, 'labelnames.json'), 'w') as f: + json.dump(labelnames, f, indent=1) + + # Sample sets for training. + full_sample = netdissect.zdataset.z_sample_for_model(model, + args.search_size, seed=10) + second_sample = netdissect.zdataset.z_sample_for_model(model, + args.search_size, seed=11) + # Load any cached data. + cache_filename = os.path.join(cachedir, 'corpus.npz') + corpus = EasyDict() + try: + if not args.no_cache: + corpus = EasyDict({k: torch.from_numpy(v) + for k, v in numpy.load(cache_filename).items()}) + except: + pass + + # The steps for the computation. + compute_present_locations(args, corpus, cache_filename, + model, segmenter, classnum, full_sample) + compute_mean_present_features(args, corpus, cache_filename, model) + compute_feature_quantiles(args, corpus, cache_filename, model, full_sample) + compute_candidate_locations(args, corpus, cache_filename, model, segmenter, + classnum, second_sample) + # visualize_training_locations(args, corpus, cachedir, model) + init_ablation = initial_ablation(args, args.outdir) + scores = train_ablation(args, corpus, cache_filename, + model, segmenter, classnum, init_ablation) + summarize_scores(args, corpus, cachedir, layer, classname, + args.variant, scores) + if args.variant == 'ace': + add_ace_ranking_to_dissection(args.outdir, layer, classname, scores) + # TODO: do some evaluation. 
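run_command caches the intermediate tensors of each phase in a single corpus.npz file; every compute_* step below adds its keys and re-saves the whole dict, so a rerun can skip phases whose keys already exist. A minimal sketch of that round-trip (a plain dict stands in for the EasyDict corpus):

import numpy, torch

corpus = {}                                   # stands in for the EasyDict corpus
corpus['present_indices'] = torch.arange(10)
corpus['feature_99'] = torch.rand(512)

# Each phase appends keys and re-saves the whole dict; tensors become numpy arrays.
numpy.savez('corpus.npz', **{k: v.numpy() for k, v in corpus.items()})

# On a later run, the cache is reloaded and converted back to torch tensors,
# so any phase whose keys are already present can return immediately.
reloaded = {k: torch.from_numpy(v) for k, v in numpy.load('corpus.npz').items()}
assert set(reloaded) == set(corpus)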
+ +class SaveImageWorker(WorkerBase): + def work(self, data, filename): + Image.fromarray(data).save(filename, optimize=True, quality=80) + +def plot_heatmap(output_filename, data, size=256): + fig = Figure(figsize=(1, 1), dpi=size) + canvas = FigureCanvas(fig) + gs = gridspec.GridSpec(1, 1, left=0.0, right=1.0, bottom=0.0, top=1.0) + ax = fig.add_subplot(gs[0]) + ax.set_axis_off() + ax.imshow(data, cmap='hot', aspect='equal', interpolation='nearest', + vmin=-1, vmax=1) + canvas.print_figure(output_filename, format='png') + + +def draw_heatmap(output_filename, data, size=256): + fig = Figure(figsize=(1, 1), dpi=size) + canvas = FigureCanvas(fig) + gs = gridspec.GridSpec(1, 1, left=0.0, right=1.0, bottom=0.0, top=1.0) + ax = fig.add_subplot(gs[0]) + ax.set_axis_off() + ax.imshow(data, cmap='hot', aspect='equal', interpolation='nearest', + vmin=-1, vmax=1) + canvas.draw() # draw the canvas, cache the renderer + image = numpy.fromstring(canvas.tostring_rgb(), dtype='uint8').reshape( + (size, size, 3)) + return image + +def compute_present_locations(args, corpus, cache_filename, + model, segmenter, classnum, full_sample): + # Phase 1. Identify a set of locations where there are doorways. + # Segment the image and find featuremap pixels that maximize the number + # of doorway pixels under the featuremap pixel. + if all(k in corpus for k in ['present_indices', + 'object_present_sample', 'object_present_location', + 'object_location_popularity', 'weighted_mean_present_feature']): + return + progress = default_progress() + feature_shape = model.feature_shape[args.layer][2:] + num_locations = numpy.prod(feature_shape).item() + num_units = model.feature_shape[args.layer][1] + with torch.no_grad(): + weighted_feature_sum = torch.zeros(num_units).cuda() + object_presence_scores = [] + for [zbatch] in progress( + torch.utils.data.DataLoader(TensorDataset(full_sample), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Object pool"): + zbatch = zbatch.cuda() + tensor_image = model(zbatch) + segmented_image = segmenter.segment_batch(tensor_image, + downsample=2) + mask = (segmented_image == classnum).max(1)[0] + score = torch.nn.functional.adaptive_avg_pool2d( + mask.float(), feature_shape) + object_presence_scores.append(score.cpu()) + feat = model.retained_layer(args.layer) + weighted_feature_sum += (feat * score[:,None,:,:]).view( + feat.shape[0],feat.shape[1], -1).sum(2).sum(0) + object_presence_at_feature = torch.cat(object_presence_scores) + object_presence_at_image, object_location_in_image = ( + object_presence_at_feature.view(args.search_size, -1).max(1)) + best_presence_scores, best_presence_images = torch.sort( + -object_presence_at_image) + all_present_indices = torch.sort( + best_presence_images[:(args.train_size+args.eval_size)])[0] + corpus.present_indices = all_present_indices[:args.train_size] + corpus.object_present_sample = full_sample[corpus.present_indices] + corpus.object_present_location = object_location_in_image[ + corpus.present_indices] + corpus.object_location_popularity = torch.bincount( + corpus.object_present_location, + minlength=num_locations) + corpus.weighted_mean_present_feature = (weighted_feature_sum.cpu() / ( + 1e-20 + object_presence_at_feature.view(-1).sum())) + corpus.eval_present_indices = all_present_indices[-args.eval_size:] + corpus.eval_present_sample = full_sample[corpus.eval_present_indices] + corpus.eval_present_location = object_location_in_image[ + corpus.eval_present_indices] + + if cache_filename: + 
numpy.savez(cache_filename, **corpus) + +def compute_mean_present_features(args, corpus, cache_filename, model): + # Phase 1.5. Figure mean activations for every channel where there + # is a doorway. + if all(k in corpus for k in ['mean_present_feature']): + return + progress = default_progress() + with torch.no_grad(): + total_present_feature = 0 + for [zbatch, featloc] in progress( + torch.utils.data.DataLoader(TensorDataset( + corpus.object_present_sample, + corpus.object_present_location), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Mean activations"): + zbatch = zbatch.cuda() + featloc = featloc.cuda() + tensor_image = model(zbatch) + feat = model.retained_layer(args.layer) + flatfeat = feat.view(feat.shape[0], feat.shape[1], -1) + sum_feature_at_obj = flatfeat[ + torch.arange(feat.shape[0]).to(feat.device), :, featloc + ].sum(0) + total_present_feature = total_present_feature + sum_feature_at_obj + corpus.mean_present_feature = (total_present_feature / len( + corpus.object_present_sample)).cpu() + if cache_filename: + numpy.savez(cache_filename, **corpus) + +def compute_feature_quantiles(args, corpus, cache_filename, model, full_sample): + # Phase 1.6. Figure the 99% and 99.9%ile of every feature. + if all(k in corpus for k in ['feature_99', 'feature_999']): + return + progress = default_progress() + with torch.no_grad(): + rq = RunningQuantile(resolution=10000) # 10x what's needed. + for [zbatch] in progress( + torch.utils.data.DataLoader(TensorDataset(full_sample), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Calculating 0.999 quantile"): + zbatch = zbatch.cuda() + tensor_image = model(zbatch) + feat = model.retained_layer(args.layer) + rq.add(feat.permute(0, 2, 3, 1 + ).contiguous().view(-1, feat.shape[1])) + result = rq.quantiles([0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999]) + corpus.feature_001 = result[:, 0].cpu() + corpus.feature_01 = result[:, 1].cpu() + corpus.feature_10 = result[:, 2].cpu() + corpus.feature_50 = result[:, 3].cpu() + corpus.feature_90 = result[:, 4].cpu() + corpus.feature_99 = result[:, 5].cpu() + corpus.feature_999 = result[:, 6].cpu() + numpy.savez(cache_filename, **corpus) + +def compute_candidate_locations(args, corpus, cache_filename, model, + segmenter, classnum, second_sample): + # Phase 2. Identify a set of candidate locations for doorways. + # Place the median doorway activation in every location of an image + # and identify where it can go that doorway pixels increase. + if all(k in corpus for k in ['candidate_indices', + 'candidate_sample', 'candidate_score', + 'candidate_location', 'object_score_at_candidate', + 'candidate_location_popularity']): + return + progress = default_progress() + feature_shape = model.feature_shape[args.layer][2:] + num_locations = numpy.prod(feature_shape).item() + with torch.no_grad(): + # Simplify - just treat all locations as possible + possible_locations = numpy.arange(num_locations) + + # Speed up search for locations, by weighting probed locations + # according to observed distribution. 
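compute_feature_quantiles above estimates per-unit quantiles in a streaming fashion with RunningQuantile. When everything fits in memory, the same numbers can be computed directly, which makes a useful sanity check; the shapes below are illustrative:

import torch

# Illustrative shapes: 64 samples of a [512, 8, 8] featuremap.
feat = torch.randn(64, 512, 8, 8)

# Flatten to one row per spatial location, one column per unit,
# matching the permute/contiguous/view used in compute_feature_quantiles.
flat = feat.permute(0, 2, 3, 1).contiguous().view(-1, feat.shape[1])

qs = torch.tensor([0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999])
per_unit = torch.quantile(flat, qs, dim=0)    # [7, 512]: one row per quantile
feature_99, feature_999 = per_unit[5], per_unit[6]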
+ location_weights = (corpus.object_location_popularity).double() + location_weights += (location_weights.mean()) / 10.0 + location_weights = location_weights / location_weights.sum() + + candidate_scores = [] + object_scores = [] + prng = numpy.random.RandomState(1) + for [zbatch] in progress( + torch.utils.data.DataLoader(TensorDataset(second_sample), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Candidate pool"): + batch_scores = torch.zeros((len(zbatch),) + feature_shape).cuda() + flat_batch_scores = batch_scores.view(len(zbatch), -1) + zbatch = zbatch.cuda() + tensor_image = model(zbatch) + segmented_image = segmenter.segment_batch(tensor_image, + downsample=2) + mask = (segmented_image == classnum).max(1)[0] + object_score = torch.nn.functional.adaptive_avg_pool2d( + mask.float(), feature_shape) + baseline_presence = mask.float().view(mask.shape[0], -1).sum(1) + + edit_mask = torch.zeros((1, 1) + feature_shape).cuda() + if '_tcm' in args.variant: + # variant: top-conditional-mean + replace_vec = (corpus.mean_present_feature + [None,:,None,None].cuda()) + else: # default: weighted mean + replace_vec = (corpus.weighted_mean_present_feature + [None,:,None,None].cuda()) + # Sample 10 random locations to examine. + for loc in prng.choice(possible_locations, replace=False, + p=location_weights, size=5): + edit_mask.zero_() + edit_mask.view(-1)[loc] = 1 + model.edit_layer(args.layer, + ablation=edit_mask, replacement=replace_vec) + tensor_image = model(zbatch) + segmented_image = segmenter.segment_batch(tensor_image, + downsample=2) + mask = (segmented_image == classnum).max(1)[0] + modified_presence = mask.float().view( + mask.shape[0], -1).sum(1) + flat_batch_scores[:,loc] = ( + modified_presence - baseline_presence) + candidate_scores.append(batch_scores.cpu()) + object_scores.append(object_score.cpu()) + + object_scores = torch.cat(object_scores) + candidate_scores = torch.cat(candidate_scores) + # Eliminate candidates where the object is present. + candidate_scores = candidate_scores * (object_scores == 0).float() + candidate_score_at_image, candidate_location_in_image = ( + candidate_scores.view(args.search_size, -1).max(1)) + best_candidate_scores, best_candidate_images = torch.sort( + -candidate_score_at_image) + all_candidate_indices = torch.sort( + best_candidate_images[:(args.train_size+args.eval_size)])[0] + corpus.candidate_indices = all_candidate_indices[:args.train_size] + corpus.candidate_sample = second_sample[corpus.candidate_indices] + corpus.candidate_location = candidate_location_in_image[ + corpus.candidate_indices] + corpus.candidate_score = candidate_score_at_image[ + corpus.candidate_indices] + corpus.object_score_at_candidate = object_scores.view( + len(object_scores), -1)[ + corpus.candidate_indices, corpus.candidate_location] + corpus.candidate_location_popularity = torch.bincount( + corpus.candidate_location, + minlength=num_locations) + corpus.eval_candidate_indices = all_candidate_indices[ + -args.eval_size:] + corpus.eval_candidate_sample = second_sample[ + corpus.eval_candidate_indices] + corpus.eval_candidate_location = candidate_location_in_image[ + corpus.eval_candidate_indices] + numpy.savez(cache_filename, **corpus) + +def visualize_training_locations(args, corpus, cachedir, model): + # Phase 2.5 Create visualizations of the corpus images. 
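The candidate search above probes one featuremap location at a time by writing a replacement vector there through model.edit_layer and re-running the generator. A self-contained sketch of the underlying single-pixel edit, assuming the instrumented model blends the retained featuremap as feat * (1 - mask) + replacement * mask (the exact semantics live in netdissect's instrumentation code):

import torch

feat = torch.randn(4, 512, 8, 8)              # retained featuremap [N, C, H, W]
replace_vec = torch.randn(1, 512, 1, 1)       # e.g. the weighted mean "present" feature

loc = 27                                      # flat index of the probed location
edit_mask = torch.zeros(1, 1, 8, 8)
edit_mask.view(-1)[loc] = 1                   # a single featuremap pixel

# Assumed blending rule for the edit: replace only the masked location.
edited = feat * (1 - edit_mask) + replace_vec * edit_mask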
+ progress = default_progress() + feature_shape = model.feature_shape[args.layer][2:] + num_locations = numpy.prod(feature_shape).item() + with torch.no_grad(): + imagedir = os.path.join(cachedir, 'image') + os.makedirs(imagedir, exist_ok=True) + image_saver = WorkerPool(SaveImageWorker) + for group, group_sample, group_location, group_indices in [ + ('present', + corpus.object_present_sample, + corpus.object_present_location, + corpus.present_indices), + ('candidate', + corpus.candidate_sample, + corpus.candidate_location, + corpus.candidate_indices)]: + for [zbatch, featloc, indices] in progress( + torch.utils.data.DataLoader(TensorDataset( + group_sample, group_location, group_indices), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Visualize %s" % group): + zbatch = zbatch.cuda() + tensor_image = model(zbatch) + feature_mask = torch.zeros((len(zbatch), 1) + feature_shape) + feature_mask.view(len(zbatch), -1).scatter_( + 1, featloc[:,None], 1) + feature_mask = torch.nn.functional.adaptive_max_pool2d( + feature_mask.float(), tensor_image.shape[-2:]).cuda() + yellow = torch.Tensor([1.0, 1.0, -1.0] + )[None, :, None, None].cuda() + tensor_image = tensor_image * (1 - 0.5 * feature_mask) + ( + 0.5 * feature_mask * yellow) + byte_image = (((tensor_image+1)/2)*255).clamp(0, 255).byte() + numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy() + for i, index in enumerate(indices): + image_saver.add(numpy_image[i], os.path.join(imagedir, + '%s_%d.jpg' % (group, index))) + image_saver.join() + +def scale_summary(scale, lownums, highnums): + value, order = (-(scale.detach())).cpu().sort(0) + lowsum = ' '.join('%d: %.3g' % (o.item(), -v.item()) + for v, o in zip(value[:lownums], order[:lownums])) + highsum = ' '.join('%d: %.3g' % (o.item(), -v.item()) + for v, o in zip(value[-highnums:], order[-highnums:])) + return lowsum + ' ... ' + highsum + +# Phase 3. Given those two sets, now optimize a such that: +# Door pred lost if we take 0 * a at a candidate (1) +# Door pred gained If we take 99.9th activation * a at a candiate (1) +# + +# ADE_au = E | on - E | off) +# = cand-frac E_cand | on + nocand-frac E_cand | on +# - door-frac E_door | off + nodoor-frac E_nodoor | off +# approx = cand-frac E_cand | on - door-frac E_door | off + K +# Each batch has both types, and minimizes +# door-frac sum(s_c) when pixel off - cand-frac sum(s_c) when pixel on + +def initial_ablation(args, dissectdir): + # Load initialization from dissection, based on iou scores. 
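visualize_training_locations above marks the probed featuremap pixel by alpha-blending a yellow highlight over the generated image before converting it to bytes. The blend in isolation, using the [-1, 1] image range produced by the generators here (yellow is (1, 1, -1) in that range):

import torch

image = torch.rand(1, 3, 256, 256) * 2 - 1        # generator output in [-1, 1]
mask = torch.zeros(1, 1, 256, 256)
mask[..., 96:160, 96:160] = 1                     # region to highlight

yellow = torch.tensor([1.0, 1.0, -1.0])[None, :, None, None]
highlighted = image * (1 - 0.5 * mask) + 0.5 * mask * yellow

# Convert to bytes for saving, as the visualization code does.
byte_image = (((highlighted + 1) / 2) * 255).clamp(0, 255).byte()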
+ with open(os.path.join(dissectdir, 'dissect.json')) as f: + dissection = EasyDict(json.load(f)) + lrec = [l for l in dissection.layers if l.layer == args.layer][0] + rrec = [r for r in lrec.rankings if r.name == '%s-iou' % args.classname + ][0] + init_scores = -torch.tensor(rrec.score) + return init_scores / init_scores.max() + +def ace_loss(segmenter, classnum, model, layer, high_replacement, ablation, + pbatch, ploc, cbatch, cloc, run_backward=False, + discrete_pixels=False, + discrete_units=False, + mixed_units=False, + ablation_only=False, + fullimage_measurement=False, + fullimage_ablation=False, + ): + feature_shape = model.feature_shape[layer][2:] + if discrete_units: # discretize ablation to the top N units + assert discrete_units > 0 + d = torch.zeros_like(ablation) + top_units = torch.topk(ablation.view(-1), discrete_units)[1] + if mixed_units: + d.view(-1)[top_units] = ablation.view(-1)[top_units] + else: + d.view(-1)[top_units] = 1 + ablation = d + # First, ablate a sample of locations with positive presence + # and see how much the presence is reduced. + p_mask = torch.zeros((len(pbatch), 1) + feature_shape) + if fullimage_ablation: + p_mask[...] = 1 + else: + p_mask.view(len(pbatch), -1).scatter_(1, ploc[:,None], 1) + p_mask = p_mask.cuda() + a_p_mask = (ablation * p_mask) + model.edit_layer(layer, ablation=a_p_mask, replacement=None) + tensor_images = model(pbatch.cuda()) + assert model._ablation[layer] is a_p_mask + erase_effect, erased_mask = segmenter.predict_single_class( + tensor_images, classnum, downsample=2) + if discrete_pixels: # pixel loss: use mask instead of pred + erase_effect = erased_mask.float() + erase_downsampled = torch.nn.functional.adaptive_avg_pool2d( + erase_effect[:,None,:,:], feature_shape)[:,0,:,:] + if fullimage_measurement: + erase_loss = erase_downsampled.sum() + else: + erase_at_loc = erase_downsampled.view(len(erase_downsampled), -1 + )[torch.arange(len(erase_downsampled)), ploc] + erase_loss = erase_at_loc.sum() + if run_backward: + erase_loss.backward() + if ablation_only: + return erase_loss + # Second, activate a sample of locations that are candidates for + # insertion and see how much the presence is increased. 
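ace_loss can evaluate a discretized version of the learned ablation vector by keeping only its top-k units, either fully on (discrete_units) or with their learned weights (mixed_units). That discretization as a standalone sketch:

import torch

ablation = torch.rand(512)               # learned per-unit ablation strengths in [0, 1]
k = 20

top_units = torch.topk(ablation, k).indices

hard = torch.zeros_like(ablation)
hard[top_units] = 1                      # discrete_units: top-k units fully ablated

mixed = torch.zeros_like(ablation)
mixed[top_units] = ablation[top_units]   # mixed_units: keep their learned weights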
+ c_mask = torch.zeros((len(cbatch), 1) + feature_shape) + c_mask.view(len(cbatch), -1).scatter_(1, cloc[:,None], 1) + c_mask = c_mask.cuda() + a_c_mask = (ablation * c_mask) + model.edit_layer(layer, ablation=a_c_mask, replacement=high_replacement) + tensor_images = model(cbatch.cuda()) + assert model._ablation[layer] is a_c_mask + add_effect, added_mask = segmenter.predict_single_class( + tensor_images, classnum, downsample=2) + if discrete_pixels: # pixel loss: use mask instead of pred + add_effect = added_mask.float() + add_effect = -add_effect + add_downsampled = torch.nn.functional.adaptive_avg_pool2d( + add_effect[:,None,:,:], feature_shape)[:,0,:,:] + if fullimage_measurement: + add_loss = add_downsampled.mean() + else: + add_at_loc = add_downsampled.view(len(add_downsampled), -1 + )[torch.arange(len(add_downsampled)), ploc] + add_loss = add_at_loc.sum() + if run_backward: + add_loss.backward() + return erase_loss + add_loss + +def train_ablation(args, corpus, cachefile, model, segmenter, classnum, + initial_ablation=None): + progress = default_progress() + cachedir = os.path.dirname(cachefile) + snapdir = os.path.join(cachedir, 'snapshots') + os.makedirs(snapdir, exist_ok=True) + + # high_replacement = corpus.feature_99[None,:,None,None].cuda() + if '_h99' in args.variant: + high_replacement = corpus.feature_99[None,:,None,None].cuda() + elif '_tcm' in args.variant: + # variant: top-conditional-mean + high_replacement = ( + corpus.mean_present_feature[None,:,None,None].cuda()) + else: # default: weighted mean + high_replacement = ( + corpus.weighted_mean_present_feature[None,:,None,None].cuda()) + fullimage_measurement = False + ablation_only = False + fullimage_ablation = False + if '_fim' in args.variant: + fullimage_measurement = True + elif '_fia' in args.variant: + fullimage_measurement = True + ablation_only = True + fullimage_ablation = True + high_replacement.requires_grad = False + for p in model.parameters(): + p.requires_grad = False + + ablation = torch.zeros(high_replacement.shape).cuda() + if initial_ablation is not None: + ablation.view(-1)[...] 
= initial_ablation + ablation.requires_grad = True + optimizer = torch.optim.Adam([ablation], lr=0.01) + start_epoch = 0 + epoch = 0 + + def eval_loss_and_reg(): + discrete_experiments = dict( + # dpixel=dict(discrete_pixels=True), + # dunits20=dict(discrete_units=20), + # dumix20=dict(discrete_units=20, mixed_units=True), + # dunits10=dict(discrete_units=10), + # abonly=dict(ablation_only=True), + # fimabl=dict(ablation_only=True, + # fullimage_ablation=True, + # fullimage_measurement=True), + dboth20=dict(discrete_units=20, discrete_pixels=True), + # dbothm20=dict(discrete_units=20, mixed_units=True, + # discrete_pixels=True), + # abdisc20=dict(discrete_units=20, discrete_pixels=True, + # ablation_only=True), + # abdiscm20=dict(discrete_units=20, mixed_units=True, + # discrete_pixels=True, + # ablation_only=True), + # fimadp=dict(discrete_pixels=True, + # ablation_only=True, + # fullimage_ablation=True, + # fullimage_measurement=True), + # fimadu10=dict(discrete_units=10, + # ablation_only=True, + # fullimage_ablation=True, + # fullimage_measurement=True), + # fimadb10=dict(discrete_units=10, discrete_pixels=True, + # ablation_only=True, + # fullimage_ablation=True, + # fullimage_measurement=True), + fimadbm10=dict(discrete_units=10, mixed_units=True, + discrete_pixels=True, + ablation_only=True, + fullimage_ablation=True, + fullimage_measurement=True), + # fimadu20=dict(discrete_units=20, + # ablation_only=True, + # fullimage_ablation=True, + # fullimage_measurement=True), + # fimadb20=dict(discrete_units=20, discrete_pixels=True, + # ablation_only=True, + # fullimage_ablation=True, + # fullimage_measurement=True), + fimadbm20=dict(discrete_units=20, mixed_units=True, + discrete_pixels=True, + ablation_only=True, + fullimage_ablation=True, + fullimage_measurement=True) + ) + with torch.no_grad(): + total_loss = 0 + discrete_losses = {k: 0 for k in discrete_experiments} + for [pbatch, ploc, cbatch, cloc] in progress( + torch.utils.data.DataLoader(TensorDataset( + corpus.eval_present_sample, + corpus.eval_present_location, + corpus.eval_candidate_sample, + corpus.eval_candidate_location), + batch_size=args.inference_batch_size, num_workers=10, + shuffle=False, pin_memory=True), + desc="Eval"): + # First, put in zeros for the selected units. + # Loss is amount of remaining object. + total_loss = total_loss + ace_loss(segmenter, classnum, + model, args.layer, high_replacement, ablation, + pbatch, ploc, cbatch, cloc, run_backward=False, + ablation_only=ablation_only, + fullimage_measurement=fullimage_measurement) + for k, config in discrete_experiments.items(): + discrete_losses[k] = discrete_losses[k] + ace_loss( + segmenter, classnum, + model, args.layer, high_replacement, ablation, + pbatch, ploc, cbatch, cloc, run_backward=False, + **config) + avg_loss = (total_loss / args.eval_size).item() + avg_d_losses = {k: (d / args.eval_size).item() + for k, d in discrete_losses.items()} + regularizer = (args.l2_lambda * ablation.pow(2).sum()) + print_progress('Epoch %d Loss %g Regularizer %g' % + (epoch, avg_loss, regularizer)) + print_progress(' '.join('%s: %g' % (k, d) + for k, d in avg_d_losses.items())) + print_progress(scale_summary(ablation.view(-1), 10, 3)) + return avg_loss, regularizer, avg_d_losses + + if args.eval_only: + # For eval_only, just load each snapshot and re-run validation eval + # pass on each one. 
+ for epoch in range(-1, args.train_epochs): + snapfile = os.path.join(snapdir, 'epoch-%d.pth' % epoch) + if not os.path.exists(snapfile): + data = {} + if epoch >= 0: + print('No epoch %d' % epoch) + continue + else: + data = torch.load(snapfile) + with torch.no_grad(): + ablation[...] = data['ablation'].to(ablation.device) + optimizer.load_state_dict(data['optimizer']) + avg_loss, regularizer, new_extra = eval_loss_and_reg() + # Keep old values, and update any new ones. + extra = {k: v for k, v in data.items() + if k not in ['ablation', 'optimizer', 'avg_loss']} + extra.update(new_extra) + torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(), + avg_loss=avg_loss, **extra), + os.path.join(snapdir, 'epoch-%d.pth' % epoch)) + # Return loaded ablation. + return ablation.view(-1).detach().cpu().numpy() + + if not args.no_cache: + for start_epoch in reversed(range(args.train_epochs)): + snapfile = os.path.join(snapdir, 'epoch-%d.pth' % start_epoch) + if os.path.exists(snapfile): + data = torch.load(snapfile) + with torch.no_grad(): + ablation[...] = data['ablation'].to(ablation.device) + optimizer.load_state_dict(data['optimizer']) + start_epoch += 1 + break + + if start_epoch < args.train_epochs: + epoch = start_epoch - 1 + avg_loss, regularizer, extra = eval_loss_and_reg() + if epoch == -1: + torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(), + avg_loss=avg_loss, **extra), + os.path.join(snapdir, 'epoch-%d.pth' % epoch)) + + update_size = args.train_update_freq * args.train_batch_size + for epoch in range(start_epoch, args.train_epochs): + candidate_shuffle = torch.randperm(len(corpus.candidate_sample)) + train_loss = 0 + for batch_num, [pbatch, ploc, cbatch, cloc] in enumerate(progress( + torch.utils.data.DataLoader(TensorDataset( + corpus.object_present_sample, + corpus.object_present_location, + corpus.candidate_sample[candidate_shuffle], + corpus.candidate_location[candidate_shuffle]), + batch_size=args.train_batch_size, num_workers=10, + shuffle=True, pin_memory=True), + desc="ACE opt epoch %d" % epoch)): + if batch_num % args.train_update_freq == 0: + optimizer.zero_grad() + # First, put in zeros for the selected units. Loss is amount + # of remaining object. + loss = ace_loss(segmenter, classnum, + model, args.layer, high_replacement, ablation, + pbatch, ploc, cbatch, cloc, run_backward=True, + ablation_only=ablation_only, + fullimage_measurement=fullimage_measurement) + with torch.no_grad(): + train_loss = train_loss + loss + if (batch_num + 1) % args.train_update_freq == 0: + # Third, add some L2 loss to encourage sparsity. + regularizer = (args.l2_lambda * update_size + * ablation.pow(2).sum()) + regularizer.backward() + optimizer.step() + with torch.no_grad(): + ablation.clamp_(0, 1) + post_progress(l=(train_loss/update_size).item(), + r=(regularizer/update_size).item()) + train_loss = 0 + + avg_loss, regularizer, extra = eval_loss_and_reg() + torch.save(dict(ablation=ablation, optimizer=optimizer.state_dict(), + avg_loss=avg_loss, **extra), + os.path.join(snapdir, 'epoch-%d.pth' % epoch)) + numpy.save(os.path.join(snapdir, 'epoch-%d.npy' % epoch), + ablation.detach().cpu().numpy()) + + # The output of this phase is this set of scores. 
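The scores saved above are what the rest of the pipeline turns into a unit ranking (for instance, `summarize_scores` below stores the negated scores so that ascending order lists the best units first). A minimal sketch of that consuming step, assuming `scores` is the flattened array returned next; `top_units_by_score` is a hypothetical helper, not part of this file:

import numpy

def top_units_by_score(scores, k=20):
    # The largest ACE scores mark the strongest ablation candidates;
    # argsort on the negated scores puts them first.
    order = numpy.argsort(-numpy.asarray(scores))
    return order[:k].tolist()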
+ return ablation.view(-1).detach().cpu().numpy() + + +def tensor_to_numpy_image_batch(tensor_image): + byte_image = (((tensor_image+1)/2)*255).clamp(0, 255).byte() + numpy_image = byte_image.permute(0, 2, 3, 1).cpu().numpy() + return numpy_image + +# Phase 4: evaluation of intervention + +def evaluate_ablation(args, model, segmenter, eval_sample, classnum, layer, + ordering): + total_bincount = 0 + data_size = 0 + progress = default_progress() + for l in model.ablation: + model.ablation[l] = None + feature_units = model.feature_shape[args.layer][1] + feature_shape = model.feature_shape[args.layer][2:] + repeats = len(ordering) + total_scores = torch.zeros(repeats + 1) + for i, batch in enumerate(progress(torch.utils.data.DataLoader( + TensorDataset(eval_sample), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Evaluate interventions")): + tensor_image = model(zbatch) + segmented_image = segmenter.segment_batch(tensor_image, + downsample=2) + mask = (segmented_image == classnum).max(1)[0] + downsampled_seg = torch.nn.functional.adaptive_avg_pool2d( + mask.float()[:,None,:,:], feature_shape)[:,0,:,:] + total_scores[0] += downsampled_seg.sum().cpu() + # Now we need to do an intervention for every location + # that had a nonzero downsampled_seg, if any. + interventions_needed = downsampled_seg.nonzero() + location_count = len(interventions_needed) + if location_count == 0: + continue + interventions_needed = interventions_needed.repeat(repeats, 1) + inter_z = batch[0][interventions_needed[:,0]].to(device) + inter_chan = torch.zeros(repeats, location_count, feature_units, + device=device) + for j, u in enumerate(ordering): + inter_chan[j:, :, u] = 1 + inter_chan = inter_chan.view(len(inter_z), feature_units) + inter_loc = interventions_needed[:,1:] + scores = torch.zeros(len(inter_z)) + batch_size = len(batch[0]) + for j in range(0, len(inter_z), batch_size): + ibz = inter_z[j:j+batch_size] + ibl = inter_loc[j:j+batch_size].t() + imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device) + imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1 + ibc = inter_chan[j:j+batch_size] + model.edit_layer(args.layer, ablation=( + imask.float()[:,None,:,:] * ibc[:,:,None,None])) + _, seg, _, _, _ = ( + recovery.recover_im_seg_bc_and_features( + [ibz], model)) + mask = (seg == classnum).max(1)[0] + downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d( + mask.float()[:,None,:,:], feature_shape)[:,0,:,:] + scores[j:j+batch_size] = downsampled_iseg[ + (torch.arange(len(ibz)),) + tuple(ibl)] + scores = scores.view(repeats, location_count).sum(1) + total_scores[1:] += scores + return total_scores + +def evaluate_interventions(args, model, segmenter, eval_sample, + classnum, layer, units): + total_bincount = 0 + data_size = 0 + progress = default_progress() + for l in model.ablation: + model.ablation[l] = None + feature_units = model.feature_shape[args.layer][1] + feature_shape = model.feature_shape[args.layer][2:] + repeats = len(ordering) + total_scores = torch.zeros(repeats + 1) + for i, batch in enumerate(progress(torch.utils.data.DataLoader( + TensorDataset(eval_sample), + batch_size=args.inference_batch_size, num_workers=10, + pin_memory=True), + desc="Evaluate interventions")): + tensor_image = model(zbatch) + segmented_image = segmenter.segment_batch(tensor_image, + downsample=2) + mask = (segmented_image == classnum).max(1)[0] + downsampled_seg = torch.nn.functional.adaptive_avg_pool2d( + mask.float()[:,None,:,:], feature_shape)[:,0,:,:] + total_scores[0] 
+= downsampled_seg.sum().cpu() + # Now we need to do an intervention for every location + # that had a nonzero downsampled_seg, if any. + interventions_needed = downsampled_seg.nonzero() + location_count = len(interventions_needed) + if location_count == 0: + continue + interventions_needed = interventions_needed.repeat(repeats, 1) + inter_z = batch[0][interventions_needed[:,0]].to(device) + inter_chan = torch.zeros(repeats, location_count, feature_units, + device=device) + for j, u in enumerate(ordering): + inter_chan[j:, :, u] = 1 + inter_chan = inter_chan.view(len(inter_z), feature_units) + inter_loc = interventions_needed[:,1:] + scores = torch.zeros(len(inter_z)) + batch_size = len(batch[0]) + for j in range(0, len(inter_z), batch_size): + ibz = inter_z[j:j+batch_size] + ibl = inter_loc[j:j+batch_size].t() + imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device) + imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1 + ibc = inter_chan[j:j+batch_size] + model.ablation[args.layer] = ( + imask.float()[:,None,:,:] * ibc[:,:,None,None]) + _, seg, _, _, _ = ( + recovery.recover_im_seg_bc_and_features( + [ibz], model)) + mask = (seg == classnum).max(1)[0] + downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d( + mask.float()[:,None,:,:], feature_shape)[:,0,:,:] + scores[j:j+batch_size] = downsampled_iseg[ + (torch.arange(len(ibz)),) + tuple(ibl)] + scores = scores.view(repeats, location_count).sum(1) + total_scores[1:] += scores + return total_scores + + +def add_ace_ranking_to_dissection(outdir, layer, classname, total_scores): + source_filename = os.path.join(outdir, 'dissect.json') + source_filename_bak = os.path.join(outdir, 'dissect.json.bak') + + # Back up the dissection (if not already backed up) before modifying + if not os.path.exists(source_filename_bak): + shutil.copy(source_filename, source_filename_bak) + + with open(source_filename) as f: + dissection = EasyDict(json.load(f)) + + ranking_name = '%s-ace' % classname + + # Remove any old ace ranking with the same name + lrec = [l for l in dissection.layers if l.layer == layer][0] + lrec.rankings = [r for r in lrec.rankings if r.name != ranking_name] + + # Now convert ace scores to rankings + new_rankings = [dict( + name=ranking_name, + score=(-total_scores).flatten().tolist(), + metric='ace')] + + # Prepend to list. 
+ lrec.rankings[2:2] = new_rankings + + # Replace the old dissect.json in-place + with open(source_filename, 'w') as f: + json.dump(dissection, f, indent=1) + +def summarize_scores(args, corpus, cachedir, layer, classname, variant, scores): + target_filename = os.path.join(cachedir, 'summary.json') + + ranking_name = '%s-%s' % (classname, variant) + # Now convert ace scores to rankings + new_rankings = [dict( + name=ranking_name, + score=(-scores).flatten().tolist(), + metric=variant)] + result = dict(layers=[dict(layer=layer, rankings=new_rankings)]) + + # Replace the old dissect.json in-place + with open(target_filename, 'w') as f: + json.dump(result, f, indent=1) + +if __name__ == '__main__': + main() diff --git a/netdissect/aceplotablate.py b/netdissect/aceplotablate.py new file mode 100644 index 0000000000000000000000000000000000000000..585195eaf973760a7d78e4da6539c343049141de --- /dev/null +++ b/netdissect/aceplotablate.py @@ -0,0 +1,54 @@ +import os, sys, argparse, json, shutil +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator +import matplotlib + +def main(): + parser = argparse.ArgumentParser(description='ACE optimization utility', + prog='python -m netdissect.aceoptimize') + parser.add_argument('--classname', type=str, default=None, + help='intervention classname') + parser.add_argument('--layer', type=str, default='layer4', + help='layer name') + parser.add_argument('--outdir', type=str, default=None, + help='dissection directory') + parser.add_argument('--metric', type=str, default=None, + help='experiment variant') + args = parser.parse_args() + + if args.metric is None: + args.metric = 'ace' + + run_command(args) + +def run_command(args): + fig = Figure(figsize=(4.5,3.5)) + FigureCanvas(fig) + ax = fig.add_subplot(111) + for metric in [args.metric, 'iou']: + jsonname = os.path.join(args.outdir, args.layer, 'fullablation', + '%s-%s.json' % (args.classname, metric)) + with open(jsonname) as f: + summary = json.load(f) + baseline = summary['baseline'] + effects = summary['ablation_effects'][:26] + norm_effects = [0] + [1.0 - e / baseline for e in effects] + ax.plot(norm_effects, label= + 'Units by ACE' if 'ace' in metric else 'Top units by IoU') + ax.set_title('Effect of ablating units for %s' % (args.classname)) + ax.grid(True) + ax.legend() + ax.set_ylabel('Portion of %s pixels removed' % args.classname) + ax.set_xlabel('Number of units ablated') + ax.set_ylim(0, 1.0) + ax.set_xlim(0, 25) + fig.tight_layout() + dirname = os.path.join(args.outdir, args.layer, 'fullablation') + fig.savefig(os.path.join(dirname, 'effect-%s-%s.png' % + (args.classname, args.metric))) + fig.savefig(os.path.join(dirname, 'effect-%s-%s.pdf' % + (args.classname, args.metric))) + +if __name__ == '__main__': + main() diff --git a/netdissect/acesummarize.py b/netdissect/acesummarize.py new file mode 100644 index 0000000000000000000000000000000000000000..345129245b461f44ef58538f02a08c3684d33f31 --- /dev/null +++ b/netdissect/acesummarize.py @@ -0,0 +1,62 @@ +import os, sys, numpy, torch, argparse, skimage, json, shutil +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator +import matplotlib + +def main(): + parser = argparse.ArgumentParser(description='ACE optimization utility', + prog='python -m netdissect.aceoptimize') + parser.add_argument('--classname', type=str, default=None, + 
help='intervention classname') + parser.add_argument('--layer', type=str, default='layer4', + help='layer name') + parser.add_argument('--l2_lambda', type=float, nargs='+', + help='l2 regularizer hyperparameter') + parser.add_argument('--outdir', type=str, default=None, + help='dissection directory') + parser.add_argument('--variant', type=str, default=None, + help='experiment variant') + args = parser.parse_args() + + if args.variant is None: + args.variant = 'ace' + + run_command(args) + +def run_command(args): + fig = Figure(figsize=(4.5,3.5)) + FigureCanvas(fig) + ax = fig.add_subplot(111) + for l2_lambda in args.l2_lambda: + variant = args.variant + if l2_lambda != 0.01: + variant += '_reg%g' % l2_lambda + + dirname = os.path.join(args.outdir, args.layer, variant, args.classname) + snapshots = os.path.join(dirname, 'snapshots') + try: + dat = [torch.load(os.path.join(snapshots, 'epoch-%d.pth' % i)) + for i in range(10)] + except: + print('Missing %s snapshots' % dirname) + return + print('reg %g' % l2_lambda) + for i in range(10): + print(i, dat[i]['avg_loss'], + len((dat[i]['ablation'] == 1).nonzero())) + + ax.plot([dat[i]['avg_loss'] for i in range(10)], + label='reg %g' % l2_lambda) + ax.set_title('%s %s' % (args.classname, args.variant)) + ax.grid(True) + ax.legend() + ax.set_ylabel('Loss') + ax.set_xlabel('Epochs') + fig.tight_layout() + dirname = os.path.join(args.outdir, args.layer, + args.variant, args.classname) + fig.savefig(os.path.join(dirname, 'loss-plot.png')) + +if __name__ == '__main__': + main() diff --git a/netdissect/actviz.py b/netdissect/actviz.py new file mode 100644 index 0000000000000000000000000000000000000000..060ea13d589544ce936ac7c7bc20cd35194d0ae9 --- /dev/null +++ b/netdissect/actviz.py @@ -0,0 +1,187 @@ +import os +import numpy +from scipy.interpolate import RectBivariateSpline + +def activation_visualization(image, data, level, alpha=0.5, source_shape=None, + crop=False, zoom=None, border=2, negate=False, return_mask=False, + **kwargs): + """ + Makes a visualiztion image of activation data overlaid on the image. + Params: + image The original image. + data The single channel feature map. + alpha The darkening to apply in inactive regions of the image. + level The threshold of activation levels to highlight. + """ + if len(image.shape) == 2: + # Puff up grayscale image to RGB. 
+ image = image[:,:,None] * numpy.array([[[1, 1, 1]]]) + surface = activation_surface(data, target_shape=image.shape[:2], + source_shape=source_shape, **kwargs) + if negate: + surface = -surface + level = -level + if crop: + # crop to source_shape + if source_shape is not None: + ch, cw = ((t - s) // 2 for s, t in zip( + source_shape, image.shape[:2])) + image = image[ch:ch+source_shape[0], cw:cw+source_shape[1]] + surface = surface[ch:ch+source_shape[0], cw:cw+source_shape[1]] + if crop is True: + crop = surface.shape + elif not hasattr(crop, '__len__'): + crop = (crop, crop) + if zoom is not None: + source_rect = best_sub_rect(surface >= level, crop, zoom, + pad=border) + else: + source_rect = (0, surface.shape[0], 0, surface.shape[1]) + image = zoom_image(image, source_rect, crop) + surface = zoom_image(surface, source_rect, crop) + mask = (surface >= level) + # Add a yellow border at the edge of the mask for contrast + result = (mask[:, :, None] * (1 - alpha) + alpha) * image + if border: + edge = mask_border(mask)[:,:,None] + result = numpy.maximum(edge * numpy.array([[[200, 200, 0]]]), result) + if not return_mask: + return result + mask_image = (1 - mask[:, :, None]) * numpy.array( + [[[0, 0, 0, 255 * (1 - alpha)]]], dtype=numpy.uint8) + if border: + mask_image = numpy.maximum(edge * numpy.array([[[200, 200, 0, 255]]]), + mask_image) + return result, mask_image + +def activation_surface(data, target_shape=None, source_shape=None, + scale_offset=None, deg=1, pad=True): + """ + Generates an upsampled activation sample. + Params: + target_shape Shape of the output array. + source_shape The centered shape of the output to match with data + when upscaling. Defaults to the whole target_shape. + scale_offset The amount by which to scale, then offset data + dimensions to end up with target dimensions. A pair of pairs. + deg Degree of interpolation to apply (1 = linear, etc). + pad True to zero-pad the edge instead of doing a funny edge interp. + """ + # Default is that nothing is resized. + if target_shape is None: + target_shape = data.shape + # Make a default scale_offset to fill the image if there isn't one + if scale_offset is None: + scale = tuple(float(ts) / ds + for ts, ds in zip(target_shape, data.shape)) + offset = tuple(0.5 * s - 0.5 for s in scale) + else: + scale, offset = (v for v in zip(*scale_offset)) + # Now we adjust offsets to take into account cropping and so on + if source_shape is not None: + offset = tuple(o + (ts - ss) / 2.0 + for o, ss, ts in zip(offset, source_shape, target_shape)) + # Pad the edge with zeros for sensible edge behavior + if pad: + zeropad = numpy.zeros( + (data.shape[0] + 2, data.shape[1] + 2), dtype=data.dtype) + zeropad[1:-1, 1:-1] = data + data = zeropad + offset = tuple((o - s) for o, s in zip(offset, scale)) + # Upsample linearly + ty, tx = (numpy.arange(ts) for ts in target_shape) + sy, sx = (numpy.arange(ss) * s + o + for ss, s, o in zip(data.shape, scale, offset)) + levels = RectBivariateSpline( + sy, sx, data, kx=deg, ky=deg)(ty, tx, grid=True) + # Return the mask. 
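A brief usage sketch of `activation_surface` (assuming the function from this module; the sizes and random feature map are illustrative): an 8x8 activation map is upsampled to a 64x64 target and then thresholded at a level.

import numpy

feature_map = numpy.random.rand(8, 8)
surface = activation_surface(feature_map, target_shape=(64, 64))
level = numpy.quantile(feature_map, 0.8)   # cf. choose_level later in this file
mask = surface >= level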
+ return levels + +def mask_border(mask, border=2): + """Given a mask computes a border mask""" + from scipy import ndimage + struct = ndimage.generate_binary_structure(2, 2) + erosion = numpy.ones((mask.shape[0] + 10, mask.shape[1] + 10), dtype='int') + erosion[5:5+mask.shape[0], 5:5+mask.shape[1]] = ~mask + for _ in range(border): + erosion = ndimage.binary_erosion(erosion, struct) + return ~mask ^ erosion[5:5+mask.shape[0], 5:5+mask.shape[1]] + +def bounding_rect(mask, pad=0): + """Returns (r, b, l, r) boundaries so that all nonzero pixels in mask + have locations (i, j) with t <= i < b, and l <= j < r.""" + nz = mask.nonzero() + if len(nz[0]) == 0: + # print('no pixels') + return (0, mask.shape[0], 0, mask.shape[1]) + (t, b), (l, r) = [(max(0, p.min() - pad), min(s, p.max() + 1 + pad)) + for p, s in zip(nz, mask.shape)] + return (t, b, l, r) + +def best_sub_rect(mask, shape, max_zoom=None, pad=2): + """Finds the smallest subrectangle containing all the nonzeros of mask, + matching the aspect ratio of shape, and where the zoom-up ratio is no + more than max_zoom""" + t, b, l, r = bounding_rect(mask, pad=pad) + height = max(b - t, int(round(float(shape[0]) * (r - l) / shape[1]))) + if max_zoom is not None: + height = int(max(round(float(shape[0]) / max_zoom), height)) + width = int(round(float(shape[1]) * height / shape[0])) + nt = min(mask.shape[0] - height, max(0, (b + t - height) // 2)) + nb = nt + height + nl = min(mask.shape[1] - width, max(0, (r + l - width) // 2)) + nr = nl + width + return (nt, nb, nl, nr) + +def zoom_image(img, source_rect, target_shape=None): + """Zooms pixels from the source_rect of img to target_shape.""" + import warnings + from scipy.ndimage import zoom + if target_shape is None: + target_shape = img.shape + st, sb, sl, sr = source_rect + source = img[st:sb, sl:sr] + if source.shape == target_shape: + return source + zoom_tuple = tuple(float(t) / s + for t, s in zip(target_shape, source.shape[:2]) + ) + (1,) * (img.ndim - 2) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) # "output shape of zoom" + target = zoom(source, zoom_tuple) + assert target.shape[:2] == target_shape, (target.shape, target_shape) + return target + +def scale_offset(dilations): + if len(dilations) == 0: + return (1, 0) + scale, offset = scale_offset(dilations[1:]) + kernel, stride, padding = dilations[0] + scale *= stride + offset *= stride + offset += (kernel - 1) / 2.0 - padding + return scale, offset + +def choose_level(feature_map, percentile=0.8): + ''' + Chooses the top 80% level (or whatever the level chosen). 
+ ''' + data_range = numpy.sort(feature_map.flatten()) + return numpy.interp( + percentile, numpy.linspace(0, 1, len(data_range)), data_range) + +def dilations(modulelist): + result = [] + for module in modulelist: + settings = tuple(getattr(module, n, d) + for n, d in (('kernel_size', 1), ('stride', 1), ('padding', 0))) + settings = (((s, s) if not isinstance(s, tuple) else s) + for s in settings) + if settings != ((1, 1), (1, 1), (0, 0)): + result.append(zip(*settings)) + return zip(*result) + +def grid_scale_offset(modulelist): + '''Returns (yscale, yoffset), (xscale, xoffset) given a list of modules''' + return tuple(scale_offset(d) for d in dilations(modulelist)) + diff --git a/netdissect/autoeval.py b/netdissect/autoeval.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc86a1f7b403f57821dde2a2b4f0619c0d6cae3 --- /dev/null +++ b/netdissect/autoeval.py @@ -0,0 +1,37 @@ +from collections import defaultdict +from importlib import import_module + +def autoimport_eval(term): + ''' + Used to evaluate an arbitrary command-line constructor specifying + a class, with automatic import of global module names. + ''' + + class DictNamespace(object): + def __init__(self, d): + self.__d__ = d + def __getattr__(self, key): + return self.__d__[key] + + class AutoImportDict(defaultdict): + def __init__(self, wrapped=None, parent=None): + super().__init__() + self.wrapped = wrapped + self.parent = parent + def __missing__(self, key): + if self.wrapped is not None: + if key in self.wrapped: + return self.wrapped[key] + if self.parent is not None: + key = self.parent + '.' + key + if key in __builtins__: + return __builtins__[key] + mdl = import_module(key) + # Return an AutoImportDict for any namespace packages + if hasattr(mdl, '__path__'): # and not hasattr(mdl, '__file__'): + return DictNamespace( + AutoImportDict(wrapped=mdl.__dict__, parent=key)) + return mdl + + return eval(term, {}, AutoImportDict()) + diff --git a/netdissect/broden.py b/netdissect/broden.py new file mode 100644 index 0000000000000000000000000000000000000000..854e87a46839c837b43cba5347967ce74ae4bf35 --- /dev/null +++ b/netdissect/broden.py @@ -0,0 +1,271 @@ +import os, errno, numpy, torch, csv, re, shutil, os, zipfile +from collections import OrderedDict +from torchvision.datasets.folder import default_loader +from torchvision import transforms +from scipy import ndimage +from urllib.request import urlopen + +class BrodenDataset(torch.utils.data.Dataset): + ''' + A multicategory segmentation data set. + + Returns three streams: + (1) The image (3, h, w). + (2) The multicategory segmentation (labelcount, h, w). + (3) A bincount of pixels in the segmentation (labelcount). + + Net dissect also assumes that the dataset object has three properties + with human-readable labels: + + ds.labels = ['red', 'black', 'car', 'tree', 'grid', ...] + ds.categories = ['color', 'part', 'object', 'texture'] + ds.label_category = [0, 0, 2, 2, 3, ...] 
# The category for each label + ''' + def __init__(self, directory='dataset/broden', resolution=384, + split='train', categories=None, + transform=None, transform_segment=None, + download=False, size=None, include_bincount=True, + broden_version=1, max_segment_depth=6): + assert resolution in [224, 227, 384] + if download: + ensure_broden_downloaded(directory, resolution, broden_version) + self.directory = directory + self.resolution = resolution + self.resdir = os.path.join(directory, 'broden%d_%d' % + (broden_version, resolution)) + self.loader = default_loader + self.transform = transform + self.transform_segment = transform_segment + self.include_bincount = include_bincount + # The maximum number of multilabel layers that coexist at an image. + self.max_segment_depth = max_segment_depth + with open(os.path.join(self.resdir, 'category.csv'), + encoding='utf-8') as f: + self.category_info = OrderedDict() + for row in csv.DictReader(f): + self.category_info[row['name']] = row + if categories is not None: + # Filter out unused categories + categories = set([c for c in categories if c in self.category_info]) + for cat in list(self.category_info.keys()): + if cat not in categories: + del self.category_info[cat] + categories = list(self.category_info.keys()) + self.categories = categories + + # Filter out unneeded images. + with open(os.path.join(self.resdir, 'index.csv'), + encoding='utf-8') as f: + all_images = [decode_index_dict(r) for r in csv.DictReader(f)] + self.image = [row for row in all_images + if index_has_any_data(row, categories) and row['split'] == split] + if size is not None: + self.image = self.image[:size] + with open(os.path.join(self.resdir, 'label.csv'), + encoding='utf-8') as f: + self.label_info = build_dense_label_array([ + decode_label_dict(r) for r in csv.DictReader(f)]) + self.labels = [l['name'] for l in self.label_info] + # Build dense remapping arrays for labels, so that you can + # get dense ranges of labels for each category. + self.category_map = {} + self.category_unmap = {} + self.category_label = {} + for cat in self.categories: + with open(os.path.join(self.resdir, 'c_%s.csv' % cat), + encoding='utf-8') as f: + c_data = [decode_label_dict(r) for r in csv.DictReader(f)] + self.category_unmap[cat], self.category_map[cat] = ( + build_numpy_category_map(c_data)) + self.category_label[cat] = build_dense_label_array( + c_data, key='code') + self.num_labels = len(self.labels) + # Primary categories for each label is the category in which it + # appears with the maximum coverage. 
+ self.label_category = numpy.zeros(self.num_labels, dtype=int) + for i in range(self.num_labels): + maxcoverage, self.label_category[i] = max( + (self.category_label[cat][self.category_map[cat][i]]['coverage'] + if i < len(self.category_map[cat]) + and self.category_map[cat][i] else 0, ic) + for ic, cat in enumerate(categories)) + + def __len__(self): + return len(self.image) + + def __getitem__(self, idx): + record = self.image[idx] + # example record: { + # 'image': 'opensurfaces/25605.jpg', 'split': 'train', + # 'ih': 384, 'iw': 384, 'sh': 192, 'sw': 192, + # 'color': ['opensurfaces/25605_color.png'], + # 'object': [], 'part': [], + # 'material': ['opensurfaces/25605_material.png'], + # 'scene': [], 'texture': []} + image = self.loader(os.path.join(self.resdir, 'images', + record['image'])) + segment = numpy.zeros(shape=(self.max_segment_depth, + record['sh'], record['sw']), dtype=int) + if self.include_bincount: + bincount = numpy.zeros(shape=(self.num_labels,), dtype=int) + depth = 0 + for cat in self.categories: + for layer in record[cat]: + if isinstance(layer, int): + segment[depth,:,:] = layer + if self.include_bincount: + bincount[layer] += segment.shape[1] * segment.shape[2] + else: + png = numpy.asarray(self.loader(os.path.join( + self.resdir, 'images', layer))) + segment[depth,:,:] = png[:,:,0] + png[:,:,1] * 256 + if self.include_bincount: + bincount += numpy.bincount(segment[depth,:,:].flatten(), + minlength=self.num_labels) + depth += 1 + if self.transform: + image = self.transform(image) + if self.transform_segment: + segment = self.transform_segment(segment) + if self.include_bincount: + bincount[0] = 0 + return (image, segment, bincount) + else: + return (image, segment) + +def build_dense_label_array(label_data, key='number', allow_none=False): + ''' + Input: set of rows with 'number' fields (or another field name key). + Output: array such that a[number] = the row with the given number. + ''' + result = [None] * (max([d[key] for d in label_data]) + 1) + for d in label_data: + result[d[key]] = d + # Fill in none + if not allow_none: + example = label_data[0] + def make_empty(k): + return dict((c, k if c is key else type(v)()) + for c, v in example.items()) + for i, d in enumerate(result): + if d is None: + result[i] = dict(make_empty(i)) + return result + +def build_numpy_category_map(map_data, key1='code', key2='number'): + ''' + Input: set of rows with 'number' fields (or another field name key). + Output: array such that a[number] = the row with the given number. 
+ ''' + results = list(numpy.zeros((max([d[key] for d in map_data]) + 1), + dtype=numpy.int16) for key in (key1, key2)) + for d in map_data: + results[0][d[key1]] = d[key2] + results[1][d[key2]] = d[key1] + return results + +def index_has_any_data(row, categories): + for c in categories: + for data in row[c]: + if data: return True + return False + +def decode_label_dict(row): + result = {} + for key, val in row.items(): + if key == 'category': + result[key] = dict((c, int(n)) + for c, n in [re.match('^([^(]*)\(([^)]*)\)$', f).groups() + for f in val.split(';')]) + elif key == 'name': + result[key] = val + elif key == 'syns': + result[key] = val.split(';') + elif re.match('^\d+$', val): + result[key] = int(val) + elif re.match('^\d+\.\d*$', val): + result[key] = float(val) + else: + result[key] = val + return result + +def decode_index_dict(row): + result = {} + for key, val in row.items(): + if key in ['image', 'split']: + result[key] = val + elif key in ['sw', 'sh', 'iw', 'ih']: + result[key] = int(val) + else: + item = [s for s in val.split(';') if s] + for i, v in enumerate(item): + if re.match('^\d+$', v): + item[i] = int(v) + result[key] = item + return result + +class ScaleSegmentation: + ''' + Utility for scaling segmentations, using nearest-neighbor zooming. + ''' + def __init__(self, target_height, target_width): + self.target_height = target_height + self.target_width = target_width + def __call__(self, seg): + ratio = (1, self.target_height / float(seg.shape[1]), + self.target_width / float(seg.shape[2])) + return ndimage.zoom(seg, ratio, order=0) + +def scatter_batch(seg, num_labels, omit_zero=True, dtype=torch.uint8): + ''' + Utility for scattering semgentations into a one-hot representation. + ''' + result = torch.zeros(*((seg.shape[0], num_labels,) + seg.shape[2:]), + dtype=dtype, device=seg.device) + result.scatter_(1, seg, 1) + if omit_zero: + result[:,0] = 0 + return result + +def ensure_broden_downloaded(directory, resolution, broden_version=1): + assert resolution in [224, 227, 384] + baseurl = 'http://netdissect.csail.mit.edu/data/' + dirname = 'broden%d_%d' % (broden_version, resolution) + if os.path.isfile(os.path.join(directory, dirname, 'index.csv')): + return # Already downloaded + zipfilename = 'broden1_%d.zip' % resolution + download_dir = os.path.join(directory, 'download') + os.makedirs(download_dir, exist_ok=True) + full_zipfilename = os.path.join(download_dir, zipfilename) + if not os.path.exists(full_zipfilename): + url = '%s/%s' % (baseurl, zipfilename) + print('Downloading %s' % url) + data = urlopen(url) + with open(full_zipfilename, 'wb') as f: + f.write(data.read()) + print('Unzipping %s' % zipfilename) + with zipfile.ZipFile(full_zipfilename, 'r') as zip_ref: + zip_ref.extractall(directory) + assert os.path.isfile(os.path.join(directory, dirname, 'index.csv')) + +def test_broden_dataset(): + ''' + Testing code. 
+    '''
+    bds = BrodenDataset('dataset/broden', resolution=384,
+            transform=transforms.Compose([
+                transforms.Resize(224),
+                transforms.ToTensor()]),
+            transform_segment=transforms.Compose([
+                ScaleSegmentation(224, 224)
+                ]),
+            include_bincount=True)
+    loader = torch.utils.data.DataLoader(bds, batch_size=100, num_workers=24)
+    for i in range(1, 20):
+        # Print each label name and its primary category, using the
+        # attributes the class defines: label_info and label_category.
+        print(bds.label_info[i]['name'],
+                bds.categories[bds.label_category[i]])
+    for i, (im, seg, bc) in enumerate(loader):
+        print(i, im.shape, seg.shape, seg.max(), bc.shape)
+
+if __name__ == '__main__':
+    test_broden_dataset()
diff --git a/netdissect/dissect.html b/netdissect/dissect.html
new file mode 100644
index 0000000000000000000000000000000000000000..e6bf4e9a418abdfef5ba09c4182bd71cf1420e52
--- /dev/null
+++ b/netdissect/dissect.html
@@ -0,0 +1,399 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+<!-- [dissect.html markup garbled; only its template bindings are recoverable.]
+     The page is the interactive viewer for dissect.json. Its per-layer header
+     reads "{{lrec.interpretable}}/{{lrec.units.length}} units covering
+     {{lrec.labels.length}} concepts with IoU ≥ {{dissect.iou_threshold}}";
+     a "sort by {{rank.name}}" selector switches rankings (including the
+     per-concept "*-{{ metric }}" rankings); and each unit is captioned
+     "{{lrec.layer}} unit {{urec.unit}} ({{urec[lk+'_cat']}})
+     iou {{urec[lk + '_iou'] | fixed(2)}} {{lk}} {{urec[lk] | fixed(2)}}"
+     together with its label "{{urec[lk+'_label']}}". -->
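The viewer binds to the same dissect.json records that the Python utilities read back. A minimal sketch of that lookup, mirroring the ranking read at the top of this section; the path, layer name, and 'door' classname are hypothetical:

import json

with open('results/dissect/dissect.json') as f:
    dissection = json.load(f)

layer_rec = next(l for l in dissection['layers'] if l['layer'] == 'layer4')
ranking = next(r for r in layer_rec['rankings'] if r['name'] == 'door-iou')
# Scores are stored negated, so ascending order lists the best units first.
best_units = sorted(range(len(ranking['score'])),
                    key=lambda u: ranking['score'][u])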
+ + + + + + + diff --git a/netdissect/dissection.py b/netdissect/dissection.py new file mode 100644 index 0000000000000000000000000000000000000000..6eef0dfd0b8804e45eb878aca68e72f8c6493474 --- /dev/null +++ b/netdissect/dissection.py @@ -0,0 +1,1617 @@ +''' +To run dissection: + +1. Load up the convolutional model you wish to dissect, and wrap it in + an InstrumentedModel; then call imodel.retain_layers([layernames,..]) + to instrument the layers of interest. +2. Load the segmentation dataset using the BrodenDataset class; + use the transform_image argument to normalize images to be + suitable for the model, or the size argument to truncate the dataset. +3. Choose a directory in which to write the output, and call + dissect(outdir, model, dataset). + +Example: + + from dissect import InstrumentedModel, dissect + from broden import BrodenDataset + + model = InstrumentedModel(load_my_model()) + model.eval() + model.cuda() + model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5']) + bds = BrodenDataset('dataset/broden1_227', + transform_image=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]), + size=1000) + dissect('result/dissect', model, bds, + examples_per_unit=10) +''' + +import torch, numpy, os, re, json, shutil, types, tempfile, torchvision +# import warnings +# warnings.simplefilter('error', UserWarning) +from PIL import Image +from xml.etree import ElementTree as et +from collections import OrderedDict, defaultdict +from .progress import verbose_progress, default_progress, print_progress +from .progress import desc_progress +from .runningstats import RunningQuantile, RunningTopK +from .runningstats import RunningCrossCovariance, RunningConditionalQuantile +from .sampler import FixedSubsetSampler +from .actviz import activation_visualization +from .segviz import segment_visualization, high_contrast +from .workerpool import WorkerBase, WorkerPool +from .segmenter import UnifiedParsingSegmenter + +def dissect(outdir, model, dataset, + segrunner=None, + train_dataset=None, + model_segmenter=None, + quantile_threshold=0.005, + iou_threshold=0.05, + iqr_threshold=0.01, + examples_per_unit=100, + batch_size=100, + num_workers=24, + seg_batch_size=5, + make_images=True, + make_labels=True, + make_maxiou=False, + make_covariance=False, + make_report=True, + make_row_images=True, + make_single_images=False, + rank_all_labels=False, + netname=None, + meta=None, + merge=None, + settings=None, + ): + ''' + Runs net dissection in-memory, using pytorch, and saves visualizations + and metadata into outdir. + ''' + assert not model.training, 'Run model.eval() before dissection' + if netname is None: + netname = type(model).__name__ + if segrunner is None: + segrunner = ClassifierSegRunner(dataset) + if train_dataset is None: + train_dataset = dataset + make_iqr = (quantile_threshold == 'iqr') + with torch.no_grad(): + device = next(model.parameters()).device + levels = None + labelnames, catnames = None, None + maxioudata, iqrdata = None, None + labeldata = None + iqrdata, cov = None, None + + labelnames, catnames = segrunner.get_label_and_category_names() + label_category = [catnames.index(c) if c in catnames else 0 + for l, c in labelnames] + + # First, always collect qunatiles and topk information. 
+ segloader = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, num_workers=num_workers, + pin_memory=(device.type == 'cuda')) + quantiles, topk = collect_quantiles_and_topk(outdir, model, + segloader, segrunner, k=examples_per_unit) + + # Thresholds can be automatically chosen by maximizing iqr + if make_iqr: + # Get thresholds based on an IQR optimization + segloader = torch.utils.data.DataLoader(train_dataset, + batch_size=1, num_workers=num_workers, + pin_memory=(device.type == 'cuda')) + iqrdata = collect_iqr(outdir, model, segloader, segrunner) + max_iqr, full_iqr_levels = iqrdata[:2] + max_iqr_agreement = iqrdata[4] + # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0 + levels = {layer: full_iqr_levels[layer][ + max_iqr[layer].max(0)[1], + torch.arange(max_iqr[layer].shape[1])].to(device) + for layer in full_iqr_levels} + else: + levels = {k: qc.quantiles([1.0 - quantile_threshold])[:,0] + for k, qc in quantiles.items()} + + quantiledata = (topk, quantiles, levels, quantile_threshold) + + if make_images: + segloader = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, num_workers=num_workers, + pin_memory=(device.type == 'cuda')) + generate_images(outdir, model, dataset, topk, levels, segrunner, + row_length=examples_per_unit, batch_size=seg_batch_size, + row_images=make_row_images, + single_images=make_single_images, + num_workers=num_workers) + + if make_maxiou: + assert train_dataset, "Need training dataset for maxiou." + segloader = torch.utils.data.DataLoader(train_dataset, + batch_size=1, num_workers=num_workers, + pin_memory=(device.type == 'cuda')) + maxioudata = collect_maxiou(outdir, model, segloader, + segrunner) + + if make_labels: + segloader = torch.utils.data.DataLoader(dataset, + batch_size=1, num_workers=num_workers, + pin_memory=(device.type == 'cuda')) + iou_scores, iqr_scores, tcs, lcs, ccs, ics = ( + collect_bincounts(outdir, model, segloader, + levels, segrunner)) + labeldata = (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, + iqr_threshold) + + if make_covariance: + segloader = torch.utils.data.DataLoader(dataset, + batch_size=seg_batch_size, + num_workers=num_workers, + pin_memory=(device.type == 'cuda')) + cov = collect_covariance(outdir, model, segloader, segrunner) + + if make_report: + generate_report(outdir, + quantiledata=quantiledata, + labelnames=labelnames, + catnames=catnames, + labeldata=labeldata, + maxioudata=maxioudata, + iqrdata=iqrdata, + covariancedata=cov, + rank_all_labels=rank_all_labels, + netname=netname, + meta=meta, + mergedata=merge, + settings=settings) + + return quantiledata, labeldata + +def generate_report(outdir, quantiledata, labelnames=None, catnames=None, + labeldata=None, maxioudata=None, iqrdata=None, covariancedata=None, + rank_all_labels=False, netname='Model', meta=None, settings=None, + mergedata=None): + ''' + Creates dissection.json reports and summary bargraph.svg files in the + specified output directory, and copies a dissection.html interface + to go along with it. + ''' + all_layers = [] + # Current source code directory, for html to copy. 
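An aside on the default thresholding rule used above: each unit's activation level is its (1 - quantile_threshold) quantile, i.e. the 0.995 quantile when quantile_threshold is 0.005. The repository computes this with the streaming RunningQuantile; the sketch below uses torch.quantile on an in-memory tensor purely for illustration, with made-up sizes:

import torch

acts = torch.randn(10000, 512).abs()   # (samples, channels), illustrative only
quantile_threshold = 0.005
levels = torch.quantile(acts, 1.0 - quantile_threshold, dim=0)  # one level per unit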
+ srcdir = os.path.realpath( + os.path.join(os.getcwd(), os.path.dirname(__file__))) + # Unpack arguments + topk, quantiles, levels, quantile_threshold = quantiledata + top_record = dict( + netname=netname, + meta=meta, + default_ranking='unit', + quantile_threshold=quantile_threshold) + if settings is not None: + top_record['settings'] = settings + if labeldata is not None: + iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, iqr_threshold = ( + labeldata) + catorder = {'object': -7, 'scene': -6, 'part': -5, + 'piece': -4, + 'material': -3, 'texture': -2, 'color': -1} + for i, cat in enumerate(c for c in catnames if c not in catorder): + catorder[cat] = i + catnumber = {n: i for i, n in enumerate(catnames)} + catnumber['-'] = 0 + top_record['default_ranking'] = 'label' + top_record['iou_threshold'] = iou_threshold + top_record['iqr_threshold'] = iqr_threshold + labelnumber = dict((name[0], num) + for num, name in enumerate(labelnames)) + # Make a segmentation color dictionary + segcolors = {} + for i, name in enumerate(labelnames): + key = ','.join(str(s) for s in high_contrast[i % len(high_contrast)]) + if key in segcolors: + segcolors[key] += '/' + name[0] + else: + segcolors[key] = name[0] + top_record['segcolors'] = segcolors + for layer in topk.keys(): + units, rankings = [], [] + record = dict(layer=layer, units=units, rankings=rankings) + # For every unit, we always have basic visualization information. + topa, topi = topk[layer].result() + lev = levels[layer] + for u in range(len(topa)): + units.append(dict( + unit=u, + interp=True, + level=lev[u].item(), + top=[dict(imgnum=i.item(), maxact=a.item()) + for i, a in zip(topi[u], topa[u])], + )) + rankings.append(dict(name="unit", score=list([ + u for u in range(len(topa))]))) + # TODO: consider including stats and ranking based on quantiles, + # variance, connectedness here. 
+ + # if we have labeldata, then every unit also gets a bunch of other info + if labeldata is not None: + lscore, qscore, cc, ic = [dat[layer] + for dat in [iou_scores, iqr_scores, ccs, ics]] + if iqrdata is not None: + # If we have IQR thresholds, assign labels based on that + max_iqr, max_iqr_level = iqrdata[:2] + best_label = max_iqr[layer].max(0)[1] + best_score = lscore[best_label, torch.arange(lscore.shape[1])] + best_qscore = qscore[best_label, torch.arange(lscore.shape[1])] + else: + # Otherwise, assign labels based on max iou + best_score, best_label = lscore.max(0) + best_qscore = qscore[best_label, torch.arange(qscore.shape[1])] + record['iou_threshold'] = iou_threshold, + for u, urec in enumerate(units): + score, qscore, label = ( + best_score[u], best_qscore[u], best_label[u]) + urec.update(dict( + iou=score.item(), + iou_iqr=qscore.item(), + lc=lcs[label].item(), + cc=cc[catnumber[labelnames[label][1]], u].item(), + ic=ic[label, u].item(), + interp=(qscore.item() > iqr_threshold and + score.item() > iou_threshold), + iou_labelnum=label.item(), + iou_label=labelnames[label.item()][0], + iou_cat=labelnames[label.item()][1], + )) + if maxioudata is not None: + max_iou, max_iou_level, max_iou_quantile = maxioudata + qualified_iou = max_iou[layer].clone() + # qualified_iou[max_iou_quantile[layer] > 0.75] = 0 + best_score, best_label = qualified_iou.max(0) + for u, urec in enumerate(units): + urec.update(dict( + maxiou=best_score[u].item(), + maxiou_label=labelnames[best_label[u].item()][0], + maxiou_cat=labelnames[best_label[u].item()][1], + maxiou_level=max_iou_level[layer][best_label[u], u].item(), + maxiou_quantile=max_iou_quantile[layer][ + best_label[u], u].item())) + if iqrdata is not None: + [max_iqr, max_iqr_level, max_iqr_quantile, + max_iqr_iou, max_iqr_agreement] = iqrdata + qualified_iqr = max_iqr[layer].clone() + qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0 + best_score, best_label = qualified_iqr.max(0) + for u, urec in enumerate(units): + urec.update(dict( + iqr=best_score[u].item(), + iqr_label=labelnames[best_label[u].item()][0], + iqr_cat=labelnames[best_label[u].item()][1], + iqr_level=max_iqr_level[layer][best_label[u], u].item(), + iqr_quantile=max_iqr_quantile[layer][ + best_label[u], u].item(), + iqr_iou=max_iqr_iou[layer][best_label[u], u].item() + )) + if covariancedata is not None: + score = covariancedata[layer].correlation() + best_score, best_label = score.max(1) + for u, urec in enumerate(units): + urec.update(dict( + cor=best_score[u].item(), + cor_label=labelnames[best_label[u].item()][0], + cor_cat=labelnames[best_label[u].item()][1] + )) + if mergedata is not None: + # Final step: if the user passed any data to merge into the + # units, merge them now. This can be used, for example, to + # indiate that a unit is not interpretable based on some + # outside analysis of unit statistics. + for lrec in mergedata.get('layers', []): + if lrec['layer'] == layer: + break + else: + lrec = None + for u, urec in enumerate(lrec.get('units', []) if lrec else []): + units[u].update(urec) + # After populating per-unit info, populate per-layer ranking info + if labeldata is not None: + # Collect all labeled units + labelunits = defaultdict(list) + all_labelunits = defaultdict(list) + for u, urec in enumerate(units): + if urec['interp']: + labelunits[urec['iou_labelnum']].append(u) + all_labelunits[urec['iou_labelnum']].append(u) + # Sort all units in order with most popular label first. 
+ label_ordering = sorted(units, + # Sort by: + key=lambda r: (-1 if r['interp'] else 0, # interpretable + -len(labelunits[r['iou_labelnum']]), # label freq, score + -max([units[u]['iou'] + for u in labelunits[r['iou_labelnum']]], default=0), + r['iou_labelnum'], # label + -r['iou'])) # unit score + # Add label and iou ranking. + rankings.append(dict(name="label", score=(numpy.argsort(list( + ur['unit'] for ur in label_ordering))).tolist())) + rankings.append(dict(name="max iou", metric="iou", score=list( + -ur['iou'] for ur in units))) + # Add ranking for top labels + # for labelnum in [n for n in sorted( + # all_labelunits.keys(), key=lambda x: + # -len(all_labelunits[x])) if len(all_labelunits[n])]: + # label = labelnames[labelnum][0] + # rankings.append(dict(name="%s-iou" % label, + # concept=label, metric='iou', + # score=(-lscore[labelnum, :]).tolist())) + # Collate labels by category then frequency. + record['labels'] = [dict( + label=labelnames[label][0], + labelnum=label, + units=labelunits[label], + cat=labelnames[label][1]) + for label in (sorted(labelunits.keys(), + # Sort by: + key=lambda l: (catorder.get( # category + labelnames[l][1], 0), + -len(labelunits[l]), # label freq + -max([units[u]['iou'] for u in labelunits[l]], + default=0) # score + ))) if len(labelunits[label])] + # Total number of interpretable units. + record['interpretable'] = sum(len(group['units']) + for group in record['labels']) + # Make a bargraph of labels + os.makedirs(os.path.join(outdir, safe_dir_name(layer)), + exist_ok=True) + catgroups = OrderedDict() + for _, cat in sorted([(v, k) for k, v in catorder.items()]): + catgroups[cat] = [] + for rec in record['labels']: + if rec['cat'] not in catgroups: + catgroups[rec['cat']] = [] + catgroups[rec['cat']].append(rec['label']) + make_svg_bargraph( + [rec['label'] for rec in record['labels']], + [len(rec['units']) for rec in record['labels']], + [(cat, len(group)) for cat, group in catgroups.items()], + filename=os.path.join(outdir, safe_dir_name(layer), + 'bargraph.svg')) + # Only show the bargraph if it is non-empty. + if len(record['labels']): + record['bargraph'] = 'bargraph.svg' + if maxioudata is not None: + rankings.append(dict(name="max maxiou", metric="maxiou", score=list( + -ur['maxiou'] for ur in units))) + if iqrdata is not None: + rankings.append(dict(name="max iqr", metric="iqr", score=list( + -ur['iqr'] for ur in units))) + if covariancedata is not None: + rankings.append(dict(name="max cor", metric="cor", score=list( + -ur['cor'] for ur in units))) + + all_layers.append(record) + # Now add the same rankings to every layer... 
+ all_labels = None + if rank_all_labels: + all_labels = [name for name, cat in labelnames] + if labeldata is not None: + # Count layers+quadrants with a given label, and sort by freq + counted_labels = defaultdict(int) + for label in [ + re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', unitrec['iou_label']) + for record in all_layers for unitrec in record['units']]: + counted_labels[label] += 1 + if all_labels is None: + all_labels = [label for count, label in sorted((-v, k) + for k, v in counted_labels.items())] + for record in all_layers: + layer = record['layer'] + for label in all_labels: + labelnum = labelnumber[label] + record['rankings'].append(dict(name="%s-iou" % label, + concept=label, metric='iou', + score=(-iou_scores[layer][labelnum, :]).tolist())) + + if maxioudata is not None: + if all_labels is None: + counted_labels = defaultdict(int) + for label in [ + re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', + unitrec['maxiou_label']) + for record in all_layers for unitrec in record['units']]: + counted_labels[label] += 1 + all_labels = [label for count, label in sorted((-v, k) + for k, v in counted_labels.items())] + qualified_iou = max_iou[layer].clone() + qualified_iou[max_iou_quantile[layer] > 0.5] = 0 + for record in all_layers: + layer = record['layer'] + for label in all_labels: + labelnum = labelnumber[label] + record['rankings'].append(dict(name="%s-maxiou" % label, + concept=label, metric='maxiou', + score=(-qualified_iou[labelnum, :]).tolist())) + + if iqrdata is not None: + if all_labels is None: + counted_labels = defaultdict(int) + for label in [ + re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', + unitrec['iqr_label']) + for record in all_layers for unitrec in record['units']]: + counted_labels[label] += 1 + all_labels = [label for count, label in sorted((-v, k) + for k, v in counted_labels.items())] + # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0 + for record in all_layers: + layer = record['layer'] + qualified_iqr = max_iqr[layer].clone() + for label in all_labels: + labelnum = labelnumber[label] + record['rankings'].append(dict(name="%s-iqr" % label, + concept=label, metric='iqr', + score=(-qualified_iqr[labelnum, :]).tolist())) + + if covariancedata is not None: + if all_labels is None: + counted_labels = defaultdict(int) + for label in [ + re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', + unitrec['cor_label']) + for record in all_layers for unitrec in record['units']]: + counted_labels[label] += 1 + all_labels = [label for count, label in sorted((-v, k) + for k, v in counted_labels.items())] + for record in all_layers: + layer = record['layer'] + score = covariancedata[layer].correlation() + for label in all_labels: + labelnum = labelnumber[label] + record['rankings'].append(dict(name="%s-cor" % label, + concept=label, metric='cor', + score=(-score[:, labelnum]).tolist())) + + for record in all_layers: + layer = record['layer'] + # Dump per-layer json inside per-layer directory + record['dirname'] = '.' 
+ with open(os.path.join(outdir, safe_dir_name(layer), 'dissect.json'), + 'w') as jsonfile: + top_record['layers'] = [record] + json.dump(top_record, jsonfile, indent=1) + # Copy the per-layer html + shutil.copy(os.path.join(srcdir, 'dissect.html'), + os.path.join(outdir, safe_dir_name(layer), 'dissect.html')) + record['dirname'] = safe_dir_name(layer) + + # Dump all-layer json in parent directory + with open(os.path.join(outdir, 'dissect.json'), 'w') as jsonfile: + top_record['layers'] = all_layers + json.dump(top_record, jsonfile, indent=1) + # Copy the all-layer html + shutil.copy(os.path.join(srcdir, 'dissect.html'), + os.path.join(outdir, 'dissect.html')) + shutil.copy(os.path.join(srcdir, 'edit.html'), + os.path.join(outdir, 'edit.html')) + + +def generate_images(outdir, model, dataset, topk, levels, + segrunner, row_length=None, gap_pixels=5, + row_images=True, single_images=False, prefix='', + batch_size=100, num_workers=24): + ''' + Creates an image strip file for every unit of every retained layer + of the model, in the format [outdir]/[layername]/[unitnum]-top.jpg. + Assumes that the indexes of topk refer to the indexes of dataset. + Limits each strip to the top row_length images. + ''' + progress = default_progress() + needed_images = {} + if row_images is False: + row_length = 1 + # Pass 1: needed_images lists all images that are topk for some unit. + for layer in topk: + topresult = topk[layer].result()[1].cpu() + for unit, row in enumerate(topresult): + for rank, imgnum in enumerate(row[:row_length]): + imgnum = imgnum.item() + if imgnum not in needed_images: + needed_images[imgnum] = [] + needed_images[imgnum].append((layer, unit, rank)) + levels = {k: v.cpu().numpy() for k, v in levels.items()} + row_length = len(row[:row_length]) + needed_sample = FixedSubsetSampler(sorted(needed_images.keys())) + device = next(model.parameters()).device + segloader = torch.utils.data.DataLoader(dataset, + batch_size=batch_size, num_workers=num_workers, + pin_memory=(device.type == 'cuda'), + sampler=needed_sample) + vizgrid, maskgrid, origrid, seggrid = [{} for _ in range(4)] + # Pass 2: populate vizgrid with visualizations of top units. + pool = None + for i, batch in enumerate( + progress(segloader, desc='Making images')): + # Reverse transformation to get the image in byte form. + seg, _, byte_im, _ = segrunner.run_and_segment_batch(batch, model, + want_rgb=True) + torch_features = model.retained_features() + scale_offset = getattr(model, 'scale_offset', None) + if pool is None: + # Distribute the work across processes: create shared mmaps. + for layer, tf in torch_features.items(): + [vizgrid[layer], maskgrid[layer], origrid[layer], + seggrid[layer]] = [ + create_temp_mmap_grid((tf.shape[1], + byte_im.shape[1], row_length, + byte_im.shape[2] + gap_pixels, depth), + dtype='uint8', + fill=255) + for depth in [3, 4, 3, 3]] + # Pass those mmaps to worker processes. 
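The sharing mechanism relied on here, and implemented by `create_temp_mmap_grid` / `shared_temp_mmap_grid` below, is simply that a numpy.memmap created with mode 'w+' can be reopened by filename in another process with mode 'r+'. A minimal sketch with hypothetical names and sizes:

import numpy, os, tempfile

shape, dtype = (4, 8, 8, 3), numpy.uint8
path = os.path.join(tempfile.mkdtemp(), 'grid.mmap')

writer = numpy.memmap(path, dtype=dtype, mode='w+', shape=shape)
writer[...] = 255                        # the parent process fills the grid
writer.flush()

reader = numpy.memmap(path, dtype=dtype, mode='r+', shape=shape)
assert int(reader[0, 0, 0, 0]) == 255    # a worker sees the same bytes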
+ pool = WorkerPool(worker=VisualizeImageWorker, + memmap_grid_info=[ + {layer: (g.filename, g.shape, g.dtype) + for layer, g in grid.items()} + for grid in [vizgrid, maskgrid, origrid, seggrid]]) + byte_im = byte_im.cpu().numpy() + numpy_seg = seg.cpu().numpy() + features = {} + for index in range(len(byte_im)): + imgnum = needed_sample.samples[index + i*segloader.batch_size] + for layer, unit, rank in needed_images[imgnum]: + if layer not in features: + features[layer] = torch_features[layer].cpu().numpy() + pool.add(layer, unit, rank, + byte_im[index], + features[layer][index, unit], + levels[layer][unit], + scale_offset[layer] if scale_offset else None, + numpy_seg[index]) + pool.join() + # Pass 3: save image strips as [outdir]/[layer]/[unitnum]-[top/orig].jpg + pool = WorkerPool(worker=SaveImageWorker) + for layer, vg in progress(vizgrid.items(), desc='Saving images'): + os.makedirs(os.path.join(outdir, safe_dir_name(layer), + prefix + 'image'), exist_ok=True) + if single_images: + os.makedirs(os.path.join(outdir, safe_dir_name(layer), + prefix + 's-image'), exist_ok=True) + og, sg, mg = origrid[layer], seggrid[layer], maskgrid[layer] + for unit in progress(range(len(vg)), desc='Units'): + for suffix, grid in [('top.jpg', vg), ('orig.jpg', og), + ('seg.png', sg), ('mask.png', mg)]: + strip = grid[unit].reshape( + (grid.shape[1], grid.shape[2] * grid.shape[3], + grid.shape[4])) + if row_images: + filename = os.path.join(outdir, safe_dir_name(layer), + prefix + 'image', '%d-%s' % (unit, suffix)) + pool.add(strip[:,:-gap_pixels,:].copy(), filename) + # Image.fromarray(strip[:,:-gap_pixels,:]).save(filename, + # optimize=True, quality=80) + if single_images: + single_filename = os.path.join(outdir, safe_dir_name(layer), + prefix + 's-image', '%d-%s' % (unit, suffix)) + pool.add(strip[:,:strip.shape[1] // row_length + - gap_pixels,:].copy(), single_filename) + # Image.fromarray(strip[:,:strip.shape[1] // row_length + # - gap_pixels,:]).save(single_filename, + # optimize=True, quality=80) + pool.join() + # Delete the shared memory map files + clear_global_shared_files([g.filename + for grid in [vizgrid, maskgrid, origrid, seggrid] + for g in grid.values()]) + +global_shared_files = {} +def create_temp_mmap_grid(shape, dtype, fill): + dtype = numpy.dtype(dtype) + filename = os.path.join(tempfile.mkdtemp(), 'temp-%s-%s.mmap' % + ('x'.join('%d' % s for s in shape), dtype.name)) + fid = open(filename, mode='w+b') + original = numpy.memmap(fid, dtype=dtype, mode='w+', shape=shape) + original.fid = fid + original[...] 
= fill + global_shared_files[filename] = original + return original + +def shared_temp_mmap_grid(filename, shape, dtype): + if filename not in global_shared_files: + global_shared_files[filename] = numpy.memmap( + filename, dtype=dtype, mode='r+', shape=shape) + return global_shared_files[filename] + +def clear_global_shared_files(filenames): + for fn in filenames: + if fn in global_shared_files: + del global_shared_files[fn] + try: + os.unlink(fn) + except OSError: + pass + +class VisualizeImageWorker(WorkerBase): + def setup(self, memmap_grid_info): + self.vizgrid, self.maskgrid, self.origrid, self.seggrid = [ + {layer: shared_temp_mmap_grid(*info) + for layer, info in grid.items()} + for grid in memmap_grid_info] + def work(self, layer, unit, rank, + byte_im, acts, level, scale_offset, seg): + self.origrid[layer][unit,:,rank,:byte_im.shape[0],:] = byte_im + [self.vizgrid[layer][unit,:,rank,:byte_im.shape[0],:], + self.maskgrid[layer][unit,:,rank,:byte_im.shape[0],:]] = ( + activation_visualization( + byte_im, + acts, + level, + scale_offset=scale_offset, + return_mask=True)) + self.seggrid[layer][unit,:,rank,:byte_im.shape[0],:] = ( + segment_visualization(seg, byte_im.shape[0:2])) + +class SaveImageWorker(WorkerBase): + def work(self, data, filename): + Image.fromarray(data).save(filename, optimize=True, quality=80) + +def score_tally_stats(label_category, tc, truth, cc, ic): + pred = cc[label_category] + total = tc[label_category][:, None] + truth = truth[:, None] + epsilon = 1e-20 # avoid division-by-zero + union = pred + truth - ic + iou = ic.double() / (union.double() + epsilon) + arr = torch.empty(size=(2, 2) + ic.shape, dtype=ic.dtype, device=ic.device) + arr[0, 0] = ic + arr[0, 1] = pred - ic + arr[1, 0] = truth - ic + arr[1, 1] = total - union + arr = arr.double() / total.double() + mi = mutual_information(arr) + je = joint_entropy(arr) + iqr = mi / je + iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0 + return iou, iqr + +def collect_quantiles_and_topk(outdir, model, segloader, + segrunner, k=100, resolution=1024): + ''' + Collects (estimated) quantile information and (exact) sorted top-K lists + for every channel in the retained layers of the model. Returns + a map of quantiles (one RunningQuantile for each layer) along with + a map of topk (one RunningTopK for each layer). + ''' + device = next(model.parameters()).device + features = model.retained_features() + cached_quantiles = { + layer: load_quantile_if_present(os.path.join(outdir, + safe_dir_name(layer)), 'quantiles.npz', + device=torch.device('cpu')) + for layer in features } + cached_topks = { + layer: load_topk_if_present(os.path.join(outdir, + safe_dir_name(layer)), 'topk.npz', + device=torch.device('cpu')) + for layer in features } + if (all(value is not None for value in cached_quantiles.values()) and + all(value is not None for value in cached_topks.values())): + return cached_quantiles, cached_topks + + layer_batch_size = 8 + all_layers = list(features.keys()) + layer_batches = [all_layers[i:i+layer_batch_size] + for i in range(0, len(all_layers), layer_batch_size)] + + quantiles, topks = {}, {} + progress = default_progress() + for layer_batch in layer_batches: + for i, batch in enumerate(progress(segloader, desc='Quantiles')): + # We don't actually care about the model output. 
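+ # The forward pass only serves to fill the retained-feature hooks;
+ # those activations feed the RunningQuantile / RunningTopK
+ # accumulators below.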
+ model(batch[0].to(device)) + features = model.retained_features() + # We care about the retained values + for key in layer_batch: + value = features[key] + if topks.get(key, None) is None: + topks[key] = RunningTopK(k) + if quantiles.get(key, None) is None: + quantiles[key] = RunningQuantile(resolution=resolution) + topvalue = value + if len(value.shape) > 2: + topvalue, _ = value.view(*(value.shape[:2] + (-1,))).max(2) + # Put the channel index last. + value = value.permute( + (0,) + tuple(range(2, len(value.shape))) + (1,) + ).contiguous().view(-1, value.shape[1]) + quantiles[key].add(value) + topks[key].add(topvalue) + # Save GPU memory + for key in layer_batch: + quantiles[key].to_(torch.device('cpu')) + topks[key].to_(torch.device('cpu')) + for layer in quantiles: + save_state_dict(quantiles[layer], + os.path.join(outdir, safe_dir_name(layer), 'quantiles.npz')) + save_state_dict(topks[layer], + os.path.join(outdir, safe_dir_name(layer), 'topk.npz')) + return quantiles, topks + +def collect_bincounts(outdir, model, segloader, levels, segrunner): + ''' + Returns label_counts, category_activation_counts, and intersection_counts, + across the data set, counting the pixels of intersection between upsampled, + thresholded model featuremaps, with segmentation classes in the segloader. + + label_counts (independent of model): pixels across the data set that + are labeled with the given label. + category_activation_counts (one per layer): for each feature channel, + pixels across the dataset where the channel exceeds the level + threshold. There is one count per category: activations only + contribute to the categories for which any category labels are + present on the images. + intersection_counts (one per layer): for each feature channel and + label, pixels across the dataset where the channel exceeds + the level, and the labeled segmentation class is also present. + + This is a performance-sensitive function. Best performance is + achieved with a counting scheme which assumes a segloader with + batch_size 1. 
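+
+ As computed by score_tally_stats above, these counts combine into a
+ per-unit, per-label IoU:
+ iou = intersection / (activation + label - intersection),
+ where the activation count is the one for the label's category.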
+ ''' + # Load cached data if present + (iou_scores, iqr_scores, + total_counts, label_counts, category_activation_counts, + intersection_counts) = {}, {}, None, None, {}, {} + found_all = True + for layer in model.retained_features(): + filename = os.path.join(outdir, safe_dir_name(layer), 'bincounts.npz') + if os.path.isfile(filename): + data = numpy.load(filename) + iou_scores[layer] = torch.from_numpy(data['iou_scores']) + iqr_scores[layer] = torch.from_numpy(data['iqr_scores']) + total_counts = torch.from_numpy(data['total_counts']) + label_counts = torch.from_numpy(data['label_counts']) + category_activation_counts[layer] = torch.from_numpy( + data['category_activation_counts']) + intersection_counts[layer] = torch.from_numpy( + data['intersection_counts']) + else: + found_all = False + if found_all: + return (iou_scores, iqr_scores, + total_counts, label_counts, category_activation_counts, + intersection_counts) + + device = next(model.parameters()).device + labelcat, categories = segrunner.get_label_and_category_names() + label_category = [categories.index(c) if c in categories else 0 + for l, c in labelcat] + num_labels, num_categories = (len(n) for n in [labelcat, categories]) + + # One-hot vector of category for each label + labelcat = torch.zeros(num_labels, num_categories, + dtype=torch.long, device=device) + labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category, + dtype='int64')).to(device)[:,None], 1) + # Running bincounts + # activation_counts = {} + assert segloader.batch_size == 1 # category_activation_counts needs this. + category_activation_counts = {} + intersection_counts = {} + label_counts = torch.zeros(num_labels, dtype=torch.long, device=device) + total_counts = torch.zeros(num_categories, dtype=torch.long, device=device) + progress = default_progress() + scale_offset_map = getattr(model, 'scale_offset', None) + upsample_grids = {} + # total_batch_categories = torch.zeros( + # labelcat.shape[1], dtype=torch.long, device=device) + for i, batch in enumerate(progress(segloader, desc='Bincounts')): + seg, batch_label_counts, _, imshape = segrunner.run_and_segment_batch( + batch, model, want_bincount=True, want_rgb=True) + bc = batch_label_counts.cpu() + batch_label_counts = batch_label_counts.to(device) + seg = seg.to(device) + features = model.retained_features() + # Accumulate bincounts and identify nonzeros + label_counts += batch_label_counts[0] + batch_labels = bc[0].nonzero()[:,0] + batch_categories = labelcat[batch_labels].max(0)[0] + total_counts += batch_categories * ( + seg.shape[0] * seg.shape[2] * seg.shape[3]) + for key, value in features.items(): + if key not in upsample_grids: + upsample_grids[key] = upsample_grid(value.shape[2:], + seg.shape[2:], imshape, + scale_offset=scale_offset_map.get(key, None) + if scale_offset_map is not None else None, + dtype=value.dtype, device=value.device) + upsampled = torch.nn.functional.grid_sample(value, + upsample_grids[key], padding_mode='border') + amask = (upsampled > levels[key][None,:,None,None].to( + upsampled.device)) + ac = amask.int().view(amask.shape[1], -1).sum(1) + # if key not in activation_counts: + # activation_counts[key] = ac + # else: + # activation_counts[key] += ac + # The fastest approach: sum over each label separately! 
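+ # For each label present in this image, intersect the thresholded
+ # activation mask with that label's segmentation mask and accumulate
+ # per-unit pixel counts into intersection_counts.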
+ for label in batch_labels.tolist(): + if label == 0: + continue # ignore the background label + imask = amask * ((seg == label).max(dim=1, keepdim=True)[0]) + ic = imask.int().view(imask.shape[1], -1).sum(1) + if key not in intersection_counts: + intersection_counts[key] = torch.zeros(num_labels, + amask.shape[1], dtype=torch.long, device=device) + intersection_counts[key][label] += ic + # Count activations within images that have category labels. + # Note: This only makes sense with batch-size one + # total_batch_categories += batch_categories + cc = batch_categories[:,None] * ac[None,:] + if key not in category_activation_counts: + category_activation_counts[key] = cc + else: + category_activation_counts[key] += cc + iou_scores = {} + iqr_scores = {} + for k in intersection_counts: + iou_scores[k], iqr_scores[k] = score_tally_stats( + label_category, total_counts, label_counts, + category_activation_counts[k], intersection_counts[k]) + for k in intersection_counts: + numpy.savez(os.path.join(outdir, safe_dir_name(k), 'bincounts.npz'), + iou_scores=iou_scores[k].cpu().numpy(), + iqr_scores=iqr_scores[k].cpu().numpy(), + total_counts=total_counts.cpu().numpy(), + label_counts=label_counts.cpu().numpy(), + category_activation_counts=category_activation_counts[k] + .cpu().numpy(), + intersection_counts=intersection_counts[k].cpu().numpy(), + levels=levels[k].cpu().numpy()) + return (iou_scores, iqr_scores, + total_counts, label_counts, category_activation_counts, + intersection_counts) + +def collect_cond_quantiles(outdir, model, segloader, segrunner): + ''' + Returns maxiou and maxiou_level across the data set, one per layer. + + This is a performance-sensitive function. Best performance is + achieved with a counting scheme which assumes a segloader with + batch_size 1. + ''' + device = next(model.parameters()).device + cached_cond_quantiles = { + layer: load_conditional_quantile_if_present(os.path.join(outdir, + safe_dir_name(layer)), 'cond_quantiles.npz') # on cpu + for layer in model.retained_features() } + label_fracs = load_npy_if_present(outdir, 'label_fracs.npy', 'cpu') + if label_fracs is not None and all( + value is not None for value in cached_cond_quantiles.values()): + return cached_cond_quantiles, label_fracs + + labelcat, categories = segrunner.get_label_and_category_names() + label_category = [categories.index(c) if c in categories else 0 + for l, c in labelcat] + num_labels, num_categories = (len(n) for n in [labelcat, categories]) + + # One-hot vector of category for each label + labelcat = torch.zeros(num_labels, num_categories, + dtype=torch.long, device=device) + labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category, + dtype='int64')).to(device)[:,None], 1) + # Running maxiou + assert segloader.batch_size == 1 # category_activation_counts needs this. 
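+ # conditional_quantiles[layer] is a RunningConditionalQuantile tracking
+ # activation quantiles under each condition: ('all',), ('label', n),
+ # and ('cat', n).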
+ conditional_quantiles = {} + label_counts = torch.zeros(num_labels, dtype=torch.long, device=device) + pixel_count = 0 + progress = default_progress() + scale_offset_map = getattr(model, 'scale_offset', None) + upsample_grids = {} + common_conditions = set() + if label_fracs is None or label_fracs is 0: + for i, batch in enumerate(progress(segloader, desc='label fracs')): + seg, batch_label_counts, im, _ = segrunner.run_and_segment_batch( + batch, model, want_bincount=True, want_rgb=True) + batch_label_counts = batch_label_counts.to(device) + features = model.retained_features() + # Accumulate bincounts and identify nonzeros + label_counts += batch_label_counts[0] + pixel_count += seg.shape[2] * seg.shape[3] + label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None] + numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs) + + skip_threshold = 1e-4 + skip_labels = set(i.item() + for i in (label_fracs.view(-1) < skip_threshold).nonzero().view(-1)) + + for layer in progress(model.retained_features().keys(), desc='CQ layers'): + if cached_cond_quantiles.get(layer, None) is not None: + conditional_quantiles[layer] = cached_cond_quantiles[layer] + continue + + for i, batch in enumerate(progress(segloader, desc='Condquant')): + seg, batch_label_counts, _, imshape = ( + segrunner.run_and_segment_batch( + batch, model, want_bincount=True, want_rgb=True)) + bc = batch_label_counts.cpu() + batch_label_counts = batch_label_counts.to(device) + features = model.retained_features() + # Accumulate bincounts and identify nonzeros + label_counts += batch_label_counts[0] + pixel_count += seg.shape[2] * seg.shape[3] + batch_labels = bc[0].nonzero()[:,0] + batch_categories = labelcat[batch_labels].max(0)[0] + cpu_seg = None + value = features[layer] + if layer not in upsample_grids: + upsample_grids[layer] = upsample_grid(value.shape[2:], + seg.shape[2:], imshape, + scale_offset=scale_offset_map.get(layer, None) + if scale_offset_map is not None else None, + dtype=value.dtype, device=value.device) + if layer not in conditional_quantiles: + conditional_quantiles[layer] = RunningConditionalQuantile( + resolution=2048) + upsampled = torch.nn.functional.grid_sample(value, + upsample_grids[layer], padding_mode='border').view( + value.shape[1], -1) + conditional_quantiles[layer].add(('all',), upsampled.t()) + cpu_upsampled = None + for label in batch_labels.tolist(): + if label in skip_labels: + continue + label_key = ('label', label) + if label_key in common_conditions: + imask = (seg == label).max(dim=1)[0].view(-1) + intersected = upsampled[:, imask] + conditional_quantiles[layer].add(('label', label), + intersected.t()) + else: + if cpu_seg is None: + cpu_seg = seg.cpu() + if cpu_upsampled is None: + cpu_upsampled = upsampled.cpu() + imask = (cpu_seg == label).max(dim=1)[0].view(-1) + intersected = cpu_upsampled[:, imask] + conditional_quantiles[layer].add(('label', label), + intersected.t()) + if num_categories > 1: + for cat in batch_categories.nonzero()[:,0]: + conditional_quantiles[layer].add(('cat', cat.item()), + upsampled.t()) + # Move the most common conditions to the GPU. 
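+ # Conditions seen often are kept on the GPU so their .add() calls stay
+ # on the fast path above; rare labels keep accumulating on the CPU.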
+ if i and not i & (i - 1): # if i is a power of 2: + cq = conditional_quantiles[layer] + common_conditions = set(cq.most_common_conditions(64)) + cq.to_('cpu', [k for k in cq.running_quantiles.keys() + if k not in common_conditions]) + # When a layer is done, get it off the GPU + conditional_quantiles[layer].to_('cpu') + + label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None] + + for cq in conditional_quantiles.values(): + cq.to_('cpu') + + for layer in conditional_quantiles: + save_state_dict(conditional_quantiles[layer], + os.path.join(outdir, safe_dir_name(layer), 'cond_quantiles.npz')) + numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs) + + return conditional_quantiles, label_fracs + + +def collect_maxiou(outdir, model, segloader, segrunner): + ''' + Returns maxiou and maxiou_level across the data set, one per layer. + + This is a performance-sensitive function. Best performance is + achieved with a counting scheme which assumes a segloader with + batch_size 1. + ''' + device = next(model.parameters()).device + conditional_quantiles, label_fracs = collect_cond_quantiles( + outdir, model, segloader, segrunner) + + labelcat, categories = segrunner.get_label_and_category_names() + label_category = [categories.index(c) if c in categories else 0 + for l, c in labelcat] + num_labels, num_categories = (len(n) for n in [labelcat, categories]) + + label_list = [('label', i) for i in range(num_labels)] + category_list = [('all',)] if num_categories <= 1 else ( + [('cat', i) for i in range(num_categories)]) + max_iou, max_iou_level, max_iou_quantile = {}, {}, {} + fracs = torch.logspace(-3, 0, 100) + progress = default_progress() + for layer, cq in progress(conditional_quantiles.items(), desc='Maxiou'): + levels = cq.conditional(('all',)).quantiles(1 - fracs) + denoms = 1 - cq.collected_normalize(category_list, levels) + isects = (1 - cq.collected_normalize(label_list, levels)) * label_fracs + unions = label_fracs + denoms[label_category, :, :] - isects + iou = isects / unions + # TODO: erase any for which threshold is bad + max_iou[layer], level_bucket = iou.max(2) + max_iou_level[layer] = levels[ + torch.arange(levels.shape[0])[None,:], level_bucket] + max_iou_quantile[layer] = fracs[level_bucket] + for layer in model.retained_features(): + numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'max_iou.npz'), + max_iou=max_iou[layer].cpu().numpy(), + max_iou_level=max_iou_level[layer].cpu().numpy(), + max_iou_quantile=max_iou_quantile[layer].cpu().numpy()) + return (max_iou, max_iou_level, max_iou_quantile) + +def collect_iqr(outdir, model, segloader, segrunner): + ''' + Returns iqr and iqr_level. + + This is a performance-sensitive function. Best performance is + achieved with a counting scheme which assumes a segloader with + batch_size 1. 
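+
+ Here iqr is the information quality ratio of each (unit, label) pair:
+ the mutual information between the thresholded unit and the label,
+ normalized by their joint entropy (see mutual_information and
+ joint_entropy defined below).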
+ ''' + max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou = {}, {}, {}, {} + max_iqr_agreement = {} + found_all = True + for layer in model.retained_features(): + filename = os.path.join(outdir, safe_dir_name(layer), 'iqr.npz') + if os.path.isfile(filename): + data = numpy.load(filename) + max_iqr[layer] = torch.from_numpy(data['max_iqr']) + max_iqr_level[layer] = torch.from_numpy(data['max_iqr_level']) + max_iqr_quantile[layer] = torch.from_numpy(data['max_iqr_quantile']) + max_iqr_iou[layer] = torch.from_numpy(data['max_iqr_iou']) + max_iqr_agreement[layer] = torch.from_numpy( + data['max_iqr_agreement']) + else: + found_all = False + if found_all: + return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou, + max_iqr_agreement) + + + device = next(model.parameters()).device + conditional_quantiles, label_fracs = collect_cond_quantiles( + outdir, model, segloader, segrunner) + + labelcat, categories = segrunner.get_label_and_category_names() + label_category = [categories.index(c) if c in categories else 0 + for l, c in labelcat] + num_labels, num_categories = (len(n) for n in [labelcat, categories]) + + label_list = [('label', i) for i in range(num_labels)] + category_list = [('all',)] if num_categories <= 1 else ( + [('cat', i) for i in range(num_categories)]) + full_mi, full_je, full_iqr = {}, {}, {} + fracs = torch.logspace(-3, 0, 100) + progress = default_progress() + for layer, cq in progress(conditional_quantiles.items(), desc='IQR'): + levels = cq.conditional(('all',)).quantiles(1 - fracs) + truth = label_fracs.to(device) + preds = (1 - cq.collected_normalize(category_list, levels) + )[label_category, :, :].to(device) + cond_isects = 1 - cq.collected_normalize(label_list, levels).to(device) + isects = cond_isects * truth + unions = truth + preds - isects + arr = torch.empty(size=(2, 2) + isects.shape, dtype=isects.dtype, + device=device) + arr[0, 0] = isects + arr[0, 1] = preds - isects + arr[1, 0] = truth - isects + arr[1, 1] = 1 - unions + arr.clamp_(0, 1) + mi = mutual_information(arr) + mi[:,:,-1] = 0 # at the 1.0 quantile should be no MI. + # Don't trust mi when less than label_frac is less than 1e-3, + # because our samples are too small. 
+ mi[label_fracs.view(-1) < 1e-3, :, :] = 0 + je = joint_entropy(arr) + iqr = mi / je + iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0 + full_mi[layer] = mi.cpu() + full_je[layer] = je.cpu() + full_iqr[layer] = iqr.cpu() + del mi, je + agreement = isects + arr[1, 1] + # When optimizing, maximize only over those pairs where the + # unit is positively correlated with the label, and where the + # threshold level is positive + positive_iqr = iqr + positive_iqr[agreement <= 0.8] = 0 + positive_iqr[(levels <= 0.0)[None, :, :].expand(positive_iqr.shape)] = 0 + # TODO: erase any for which threshold is bad + maxiqr, level_bucket = positive_iqr.max(2) + max_iqr[layer] = maxiqr.cpu() + max_iqr_level[layer] = levels.to(device)[ + torch.arange(levels.shape[0])[None,:], level_bucket].cpu() + max_iqr_quantile[layer] = fracs.to(device)[level_bucket].cpu() + max_iqr_agreement[layer] = agreement[ + torch.arange(agreement.shape[0])[:, None], + torch.arange(agreement.shape[1])[None, :], + level_bucket].cpu() + + # Compute the iou that goes with each maximized iqr + matching_iou = (isects[ + torch.arange(isects.shape[0])[:, None], + torch.arange(isects.shape[1])[None, :], + level_bucket] / + unions[ + torch.arange(unions.shape[0])[:, None], + torch.arange(unions.shape[1])[None, :], + level_bucket]) + matching_iou[torch.isnan(matching_iou)] = 0 + max_iqr_iou[layer] = matching_iou.cpu() + for layer in model.retained_features(): + numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'iqr.npz'), + max_iqr=max_iqr[layer].cpu().numpy(), + max_iqr_level=max_iqr_level[layer].cpu().numpy(), + max_iqr_quantile=max_iqr_quantile[layer].cpu().numpy(), + max_iqr_iou=max_iqr_iou[layer].cpu().numpy(), + max_iqr_agreement=max_iqr_agreement[layer].cpu().numpy(), + full_mi=full_mi[layer].cpu().numpy(), + full_je=full_je[layer].cpu().numpy(), + full_iqr=full_iqr[layer].cpu().numpy()) + return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou, + max_iqr_agreement) + +def mutual_information(arr): + total = 0 + for j in range(arr.shape[0]): + for k in range(arr.shape[1]): + joint = arr[j,k] + ind = arr[j,:].sum(dim=0) * arr[:,k].sum(dim=0) + term = joint * (joint / ind).log() + term[torch.isnan(term)] = 0 + total += term + return total.clamp_(0) + +def joint_entropy(arr): + total = 0 + for j in range(arr.shape[0]): + for k in range(arr.shape[1]): + joint = arr[j,k] + term = joint * joint.log() + term[torch.isnan(term)] = 0 + total += term + return (-total).clamp_(0) + +def information_quality_ratio(arr): + iqr = mutual_information(arr) / joint_entropy(arr) + iqr[torch.isnan(iqr)] = 0 + return iqr + +def collect_covariance(outdir, model, segloader, segrunner): + ''' + Returns label_mean, label_variance, unit_mean, unit_variance, + and cross_covariance across the data set. + + label_mean, label_variance (independent of model): + treating the label as a one-hot, each label's mean and variance. + unit_mean, unit_variance (one per layer): for each feature channel, + the mean and variance of the activations in that channel. + cross_covariance (one per layer): the cross covariance between the + labels and the units in the layer. 
+ ''' + device = next(model.parameters()).device + cached_covariance = { + layer: load_covariance_if_present(os.path.join(outdir, + safe_dir_name(layer)), 'covariance.npz', device=device) + for layer in model.retained_features() } + if all(value is not None for value in cached_covariance.values()): + return cached_covariance + labelcat, categories = segrunner.get_label_and_category_names() + label_category = [categories.index(c) if c in categories else 0 + for l, c in labelcat] + num_labels, num_categories = (len(n) for n in [labelcat, categories]) + + # Running covariance + cov = {} + progress = default_progress() + scale_offset_map = getattr(model, 'scale_offset', None) + upsample_grids = {} + for i, batch in enumerate(progress(segloader, desc='Covariance')): + seg, _, _, imshape = segrunner.run_and_segment_batch(batch, model, + want_rgb=True) + features = model.retained_features() + ohfeats = multilabel_onehot(seg, num_labels, ignore_index=0) + # Accumulate bincounts and identify nonzeros + for key, value in features.items(): + if key not in upsample_grids: + upsample_grids[key] = upsample_grid(value.shape[2:], + seg.shape[2:], imshape, + scale_offset=scale_offset_map.get(key, None) + if scale_offset_map is not None else None, + dtype=value.dtype, device=value.device) + upsampled = torch.nn.functional.grid_sample(value, + upsample_grids[key].expand( + (value.shape[0],) + upsample_grids[key].shape[1:]), + padding_mode='border') + if key not in cov: + cov[key] = RunningCrossCovariance() + cov[key].add(upsampled, ohfeats) + for layer in cov: + save_state_dict(cov[layer], + os.path.join(outdir, safe_dir_name(layer), 'covariance.npz')) + return cov + +def multilabel_onehot(labels, num_labels, dtype=None, ignore_index=None): + ''' + Converts a multilabel tensor into a onehot tensor. + + The input labels is a tensor of shape (samples, multilabels, y, x). + The output is a tensor of shape (samples, num_labels, y, x). + If ignore_index is specified, labels with that index are ignored. + Each x in labels should be 0 <= x < num_labels, or x == ignore_index. 
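+
+ Example (illustrative):
+
+ labels = torch.tensor([[[[1, 2]], [[0, 3]]]]) # shape (1, 2, 1, 2)
+ onehot = multilabel_onehot(labels, 4, ignore_index=0)
+ # onehot.shape == (1, 4, 1, 2); channels 1, 2 and 3 mark their pixels,
+ # and channel 0 (the ignored background) stays zero.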
+ ''' + assert ignore_index is None or ignore_index <= 0 + if dtype is None: + dtype = torch.float + device = labels.device + chans = num_labels + (-ignore_index if ignore_index else 0) + outshape = (labels.shape[0], chans) + labels.shape[2:] + result = torch.zeros(outshape, device=device, dtype=dtype) + if ignore_index and ignore_index < 0: + labels = labels + (-ignore_index) + result.scatter_(1, labels, 1) + if ignore_index and ignore_index < 0: + result = result[:, -ignore_index:] + elif ignore_index is not None: + result[:, ignore_index] = 0 + return result + +def load_npy_if_present(outdir, filename, device): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + return torch.from_numpy(data).to(device) + return 0 + +def load_npz_if_present(outdir, filename, varnames, device): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + numpy_result = [data[n] for n in varnames] + return tuple(torch.from_numpy(data).to(device) for data in numpy_result) + return None + +def load_quantile_if_present(outdir, filename, device): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + result = RunningQuantile(state=data) + result.to_(device) + return result + return None + +def load_conditional_quantile_if_present(outdir, filename): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + result = RunningConditionalQuantile(state=data) + return result + return None + +def load_topk_if_present(outdir, filename, device): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + result = RunningTopK(state=data) + result.to_(device) + return result + return None + +def load_covariance_if_present(outdir, filename, device): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + result = RunningCrossCovariance(state=data) + result.to_(device) + return result + return None + +def save_state_dict(obj, filepath): + dirname = os.path.dirname(filepath) + os.makedirs(dirname, exist_ok=True) + dic = obj.state_dict() + numpy.savez(filepath, **dic) + +def upsample_grid(data_shape, target_shape, input_shape=None, + scale_offset=None, dtype=torch.float, device=None): + '''Prepares a grid to use with grid_sample to upsample a batch of + features in data_shape to the target_shape. Can use scale_offset + and input_shape to center the grid in a nondefault way: scale_offset + maps feature pixels to input_shape pixels, and it is assumed that + the target_shape is a uniform downsampling of input_shape.''' + # Default is that nothing is resized. + if target_shape is None: + target_shape = data_shape + # Make a default scale_offset to fill the image if there isn't one + if scale_offset is None: + scale = tuple(float(ts) / ds + for ts, ds in zip(target_shape, data_shape)) + offset = tuple(0.5 * s - 0.5 for s in scale) + else: + scale, offset = (v for v in zip(*scale_offset)) + # Handle downsampling for different input vs target shape. 
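+ # Convert the feature-to-input mapping given by scale_offset into a
+ # feature-to-target mapping, since target_shape is assumed to be a
+ # uniform downsampling of input_shape.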
+ if input_shape is not None: + scale = tuple(s * (ts - 1) / (ns - 1) + for s, ns, ts in zip(scale, input_shape, target_shape)) + offset = tuple(o * (ts - 1) / (ns - 1) + for o, ns, ts in zip(offset, input_shape, target_shape)) + # Pytorch needs target coordinates in terms of source coordinates [-1..1] + ty, tx = (((torch.arange(ts, dtype=dtype, device=device) - o) + * (2 / (s * (ss - 1))) - 1) + for ts, ss, s, o, in zip(target_shape, data_shape, scale, offset)) + # Whoa, note that grid_sample reverses the order y, x -> x, y. + grid = torch.stack( + (tx[None,:].expand(target_shape), ty[:,None].expand(target_shape)),2 + )[None,:,:,:].expand((1, target_shape[0], target_shape[1], 2)) + return grid + +def safe_dir_name(filename): + keepcharacters = (' ','.','_','-') + return ''.join(c + for c in filename if c.isalnum() or c in keepcharacters).rstrip() + +bargraph_palette = [ + ('#4B4CBF', '#B6B6F2'), + ('#55B05B', '#B6F2BA'), + ('#50BDAC', '#A5E5DB'), + ('#81C679', '#C0FF9B'), + ('#F0883B', '#F2CFB6'), + ('#D4CF24', '#F2F1B6'), + ('#D92E2B', '#F2B6B6'), + ('#AB6BC6', '#CFAAFF'), +] + +def make_svg_bargraph(labels, heights, categories, + barheight=100, barwidth=12, show_labels=True, filename=None): + # if len(labels) == 0: + # return # Nothing to do + unitheight = float(barheight) / max(max(heights, default=1), 1) + textheight = barheight if show_labels else 0 + labelsize = float(barwidth) + gap = float(barwidth) / 4 + textsize = barwidth + gap + rollup = max(heights, default=1) + textmargin = float(labelsize) * 2 / 3 + leftmargin = 32 + rightmargin = 8 + svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin + svgheight = barheight + textheight + + # create an SVG XML element + svg = et.Element('svg', width=str(svgwidth), height=str(svgheight), + version='1.1', xmlns='http://www.w3.org/2000/svg') + + # Draw the bar graph + basey = svgheight - textheight + x = leftmargin + # Add units scale on left + if len(heights): + for h in [1, (max(heights) + 1) // 2, max(heights)]: + et.SubElement(svg, 'text', x='0', y='0', + style=('font-family:sans-serif;font-size:%dpx;' + + 'text-anchor:end;alignment-baseline:hanging;' + + 'transform:translate(%dpx, %dpx);') % + (textsize, x - gap, basey - h * unitheight)).text = str(h) + et.SubElement(svg, 'text', x='0', y='0', + style=('font-family:sans-serif;font-size:%dpx;' + + 'text-anchor:middle;' + + 'transform:translate(%dpx, %dpx) rotate(-90deg)') % + (textsize, x - gap - textsize, basey - h * unitheight / 2) + ).text = 'units' + # Draw big category background rectangles + for catindex, (cat, catcount) in enumerate(categories): + if not catcount: + continue + et.SubElement(svg, 'rect', x=str(x), y=str(basey - rollup * unitheight), + width=(str((barwidth + gap) * catcount - gap)), + height = str(rollup*unitheight), + fill=bargraph_palette[catindex % len(bargraph_palette)][1]) + x += (barwidth + gap) * catcount + # Draw small bars as well as 45degree text labels + x = leftmargin + catindex = -1 + catcount = 0 + for label, height in zip(labels, heights): + while not catcount and catindex <= len(categories): + catindex += 1 + catcount = categories[catindex][1] + color = bargraph_palette[catindex % len(bargraph_palette)][0] + et.SubElement(svg, 'rect', x=str(x), y=str(basey-(height * unitheight)), + width=str(barwidth), height=str(height * unitheight), + fill=color) + x += barwidth + if show_labels: + et.SubElement(svg, 'text', x='0', y='0', + style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+ + 'transform:translate(%dpx, %dpx) 
rotate(-45deg);') % + (labelsize, x, basey + textmargin)).text = readable(label) + x += gap + catcount -= 1 + # Text labels for each category + x = leftmargin + for cat, catcount in categories: + if not catcount: + continue + et.SubElement(svg, 'text', x='0', y='0', + style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+ + 'transform:translate(%dpx, %dpx) rotate(-90deg);') % + (textsize, x + (barwidth + gap) * catcount - gap, + basey - rollup * unitheight + gap)).text = '%d %s' % ( + catcount, readable(cat + ('s' if catcount != 1 else ''))) + x += (barwidth + gap) * catcount + # Output - this is the bare svg. + result = et.tostring(svg) + if filename: + f = open(filename, 'wb') + # When writing to a file a special header is needed. + f.write(''.join([ + '\n', + '\n'] + ).encode('utf-8')) + f.write(result) + f.close() + return result + +readable_replacements = [(re.compile(r[0]), r[1]) for r in [ + (r'-[sc]$', ''), + (r'_', ' '), + ]] + +def readable(label): + for pattern, subst in readable_replacements: + label= re.sub(pattern, subst, label) + return label + +def reverse_normalize_from_transform(transform): + ''' + Crawl around the transforms attached to a dataset looking for a + Normalize transform, and return it a corresponding ReverseNormalize, + or None if no normalization is found. + ''' + if isinstance(transform, torchvision.transforms.Normalize): + return ReverseNormalize(transform.mean, transform.std) + t = getattr(transform, 'transform', None) + if t is not None: + return reverse_normalize_from_transform(t) + transforms = getattr(transform, 'transforms', None) + if transforms is not None: + for t in reversed(transforms): + result = reverse_normalize_from_transform(t) + if result is not None: + return result + return None + +class ReverseNormalize: + ''' + Applies the reverse of torchvision.transforms.Normalize. + ''' + def __init__(self, mean, stdev): + mean = numpy.array(mean) + stdev = numpy.array(stdev) + self.mean = torch.from_numpy(mean)[None,:,None,None].float() + self.stdev = torch.from_numpy(stdev)[None,:,None,None].float() + def __call__(self, data): + device = data.device + return data.mul(self.stdev.to(device)).add_(self.mean.to(device)) + +class ImageOnlySegRunner: + def __init__(self, dataset, recover_image=None): + if recover_image is None: + recover_image = reverse_normalize_from_transform(dataset) + self.recover_image = recover_image + self.dataset = dataset + def get_label_and_category_names(self): + return [('-', '-')], ['-'] + def run_and_segment_batch(self, batch, model, + want_bincount=False, want_rgb=False): + [im] = batch + device = next(model.parameters()).device + if want_rgb: + rgb = self.recover_image(im.clone() + ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte() + else: + rgb = None + # Stubs for seg and bc + seg = torch.zeros(im.shape[0], 1, 1, 1, dtype=torch.long) + bc = torch.ones(im.shape[0], 1, dtype=torch.long) + # Run the model. 
+ model(im.to(device)) + return seg, bc, rgb, im.shape[2:] + +class ClassifierSegRunner: + def __init__(self, dataset, recover_image=None): + # The dataset contains explicit segmentations + if recover_image is None: + recover_image = reverse_normalize_from_transform(dataset) + self.recover_image = recover_image + self.dataset = dataset + def get_label_and_category_names(self): + catnames = self.dataset.categories + label_and_cat_names = [(readable(label), + catnames[self.dataset.label_category[i]]) + for i, label in enumerate(self.dataset.labels)] + return label_and_cat_names, catnames + def run_and_segment_batch(self, batch, model, + want_bincount=False, want_rgb=False): + ''' + Runs the dissected model on one batch of the dataset, and + returns a multilabel semantic segmentation for the data. + Given a batch of size (n, c, y, x) the segmentation should + be a (long integer) tensor of size (n, d, y//r, x//r) where + d is the maximum number of simultaneous labels given to a pixel, + and where r is some (optional) resolution reduction factor. + In the segmentation returned, the label `0` is reserved for + the background "no-label". + + In addition to the segmentation, bc, rgb, and shape are returned + where bc is a per-image bincount counting returned label pixels, + rgb is a viewable (n, y, x, rgb) byte image tensor for the data + for visualizations (reversing normalizations, for example), and + shape is the (y, x) size of the data. If want_bincount or + want_rgb are False, those return values may be None. + ''' + im, seg, bc = batch + device = next(model.parameters()).device + if want_rgb: + rgb = self.recover_image(im.clone() + ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte() + else: + rgb = None + # Run the model. + model(im.to(device)) + return seg, bc, rgb, im.shape[2:] + +class GeneratorSegRunner: + def __init__(self, segmenter): + # The segmentations are given by an algorithm + if segmenter is None: + segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad') + self.segmenter = segmenter + self.num_classes = len(segmenter.get_label_and_category_names()[0]) + def get_label_and_category_names(self): + return self.segmenter.get_label_and_category_names() + def run_and_segment_batch(self, batch, model, + want_bincount=False, want_rgb=False): + ''' + Runs the dissected model on one batch of the dataset, and + returns a multilabel semantic segmentation for the data. + Given a batch of size (n, c, y, x) the segmentation should + be a (long integer) tensor of size (n, d, y//r, x//r) where + d is the maximum number of simultaneous labels given to a pixel, + and where r is some (optional) resolution reduction factor. + In the segmentation returned, the label `0` is reserved for + the background "no-label". + + In addition to the segmentation, bc, rgb, and shape are returned + where bc is a per-image bincount counting returned label pixels, + rgb is a viewable (n, y, x, rgb) byte image tensor for the data + for visualizations (reversing normalizations, for example), and + shape is the (y, x) size of the data. If want_bincount or + want_rgb are False, those return values may be None. 
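+
+ For the generator, each batch holds latent z vectors; the images are
+ produced by the model itself and then segmented by the segmenter.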
+ ''' + device = next(model.parameters()).device + z_batch = batch[0] + tensor_images = model(z_batch.to(device)) + seg = self.segmenter.segment_batch(tensor_images, downsample=2) + if want_bincount: + index = torch.arange(z_batch.shape[0], + dtype=torch.long, device=device) + bc = (seg + index[:, None, None, None] * self.num_classes).view(-1 + ).bincount(minlength=z_batch.shape[0] * self.num_classes) + bc = bc.view(z_batch.shape[0], self.num_classes) + else: + bc = None + if want_rgb: + images = ((tensor_images + 1) / 2 * 255) + rgb = images.permute(0, 2, 3, 1).clamp(0, 255).byte() + else: + rgb = None + return seg, bc, rgb, tensor_images.shape[2:] diff --git a/netdissect/easydict.py b/netdissect/easydict.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f524b87eef75c175772ff262b93b47919ba7 --- /dev/null +++ b/netdissect/easydict.py @@ -0,0 +1,126 @@ +''' +From https://github.com/makinacorpus/easydict. +''' + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> map(attrgetter('x'), d.bar) + [1, 3] + >>> map(attrgetter('y'), d.bar) + [2, 4] + >>> d = EasyDict() + >>> d.keys() + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> o.items() + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + +def load_json(filename): + import json + with open(filename) as f: + return EasyDict(json.load(f)) + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/netdissect/edit.html b/netdissect/edit.html new file mode 100644 index 0000000000000000000000000000000000000000..9aac30bb08171c4c58eb936f9ba382e85a184803 --- /dev/null +++ b/netdissect/edit.html @@ -0,0 +1,805 @@ + + + + + + + + + + + + + +
+<!-- edit.html markup not recoverable from this extract; surviving text
+     fragments: per-unit labels "{{urec.layer}} {{urec.unit}}", a
+     "Seeds to generate" field, example thumbnails labeled "#{{ ex.id }}",
+     and the help text below. -->
+To transfer activations from one pixel to another (1) click on a source pixel
+on the left image and (2) click on a target pixel on a right image,
+then (3) choose a set of units to insert in the palette.
+ + + + diff --git a/netdissect/evalablate.py b/netdissect/evalablate.py new file mode 100644 index 0000000000000000000000000000000000000000..2079ffdb303b288df77678109f701e40fdf5779b --- /dev/null +++ b/netdissect/evalablate.py @@ -0,0 +1,248 @@ +import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL +from torchvision import transforms +from torch.utils.data import TensorDataset +from netdissect.progress import default_progress, post_progress, desc_progress +from netdissect.progress import verbose_progress, print_progress +from netdissect.nethook import edit_layers +from netdissect.zdataset import standard_z_sample +from netdissect.autoeval import autoimport_eval +from netdissect.easydict import EasyDict +from netdissect.modelconfig import create_instrumented_model + +help_epilog = '''\ +Example: + +python -m netdissect.evalablate \ + --segmenter "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')" \ + --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \ + --outdir dissect/dissectdir \ + --classes mirror coffeetable tree \ + --layers layer4 \ + --size 1000 + +Output layout: +dissectdir/layer5/ablation/mirror-iqr.json +{ class: "mirror", + classnum: 43, + pixel_total: 41342300, + class_pixels: 1234531, + layer: "layer5", + ranking: "mirror-iqr", + ablation_units: [341, 23, 12, 142, 83, ...] + ablation_pixels: [143242, 132344, 429931, ...] +} + +''' + +def main(): + # Training settings + def strpair(arg): + p = tuple(arg.split(':')) + if len(p) == 1: + p = p + p + return p + + parser = argparse.ArgumentParser(description='Ablation eval', + epilog=textwrap.dedent(help_epilog), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument('--model', type=str, default=None, + help='constructor for the model to test') + parser.add_argument('--pthfile', type=str, default=None, + help='filename of .pth file for the model') + parser.add_argument('--outdir', type=str, default='dissect', required=True, + help='directory for dissection output') + parser.add_argument('--layers', type=strpair, nargs='+', + help='space-separated list of layer names to edit' + + ', in the form layername[:reportedname]') + parser.add_argument('--classes', type=str, nargs='+', + help='space-separated list of class names to ablate') + parser.add_argument('--metric', type=str, default='iou', + help='ordering metric for selecting units') + parser.add_argument('--unitcount', type=int, default=30, + help='number of units to ablate') + parser.add_argument('--segmenter', type=str, + help='directory containing segmentation dataset') + parser.add_argument('--netname', type=str, default=None, + help='name for network in generated reports') + parser.add_argument('--batch_size', type=int, default=5, + help='batch size for forward pass') + parser.add_argument('--size', type=int, default=200, + help='number of images to test') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA usage') + parser.add_argument('--quiet', action='store_true', default=False, + help='silences console output') + if len(sys.argv) == 1: + parser.print_usage(sys.stderr) + sys.exit(1) + args = parser.parse_args() + + # Set up console output + verbose_progress(not args.quiet) + + # Speed up pytorch + torch.backends.cudnn.benchmark = True + + # Set up CUDA + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + torch.backends.cudnn.benchmark = True + + # Take defaults for model constructor etc from dissect.json settings. 
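+ # dissect.json was produced by the dissection run; its settings block
+ # supplies fallback values for --model, --pthfile, and --segmenter.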
+ with open(os.path.join(args.outdir, 'dissect.json')) as f: + dissection = EasyDict(json.load(f)) + if args.model is None: + args.model = dissection.settings.model + if args.pthfile is None: + args.pthfile = dissection.settings.pthfile + if args.segmenter is None: + args.segmenter = dissection.settings.segmenter + + # Instantiate generator + model = create_instrumented_model(args, gen=True, edit=True) + if model is None: + print('No model specified') + sys.exit(1) + + # Instantiate model + device = next(model.parameters()).device + input_shape = model.input_shape + + # 4d input if convolutional, 2d input if first layer is linear. + raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view( + (args.size,) + input_shape[1:]) + dataset = TensorDataset(raw_sample) + + # Create the segmenter + segmenter = autoimport_eval(args.segmenter) + + # Now do the actual work. + labelnames, catnames = ( + segmenter.get_label_and_category_names(dataset)) + label_category = [catnames.index(c) if c in catnames else 0 + for l, c in labelnames] + labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)} + + segloader = torch.utils.data.DataLoader(dataset, + batch_size=args.batch_size, num_workers=10, + pin_memory=(device.type == 'cuda')) + + # Index the dissection layers by layer name. + dissect_layer = {lrec.layer: lrec for lrec in dissection.layers} + + # First, collect a baseline + for l in model.ablation: + model.ablation[l] = None + + # For each sort-order, do an ablation + progress = default_progress() + for classname in progress(args.classes): + post_progress(c=classname) + for layername in progress(model.ablation): + post_progress(l=layername) + rankname = '%s-%s' % (classname, args.metric) + classnum = labelnum_from_name[classname] + try: + ranking = next(r for r in dissect_layer[layername].rankings + if r.name == rankname) + except: + print('%s not found' % rankname) + sys.exit(1) + ordering = numpy.argsort(ranking.score) + # Check if already done + ablationdir = os.path.join(args.outdir, layername, 'pixablation') + if os.path.isfile(os.path.join(ablationdir, '%s.json'%rankname)): + with open(os.path.join(ablationdir, '%s.json'%rankname)) as f: + data = EasyDict(json.load(f)) + # If the unit ordering is not the same, something is wrong + if not all(a == o + for a, o in zip(data.ablation_units, ordering)): + continue + if len(data.ablation_effects) >= args.unitcount: + continue # file already done. 
+ measurements = data.ablation_effects + measurements = measure_ablation(segmenter, segloader, + model, classnum, layername, ordering[:args.unitcount]) + measurements = measurements.cpu().numpy().tolist() + os.makedirs(ablationdir, exist_ok=True) + with open(os.path.join(ablationdir, '%s.json'%rankname), 'w') as f: + json.dump(dict( + classname=classname, + classnum=classnum, + baseline=measurements[0], + layer=layername, + metric=args.metric, + ablation_units=ordering.tolist(), + ablation_effects=measurements[1:]), f) + +def measure_ablation(segmenter, loader, model, classnum, layer, ordering): + total_bincount = 0 + data_size = 0 + device = next(model.parameters()).device + progress = default_progress() + for l in model.ablation: + model.ablation[l] = None + feature_units = model.feature_shape[layer][1] + feature_shape = model.feature_shape[layer][2:] + repeats = len(ordering) + total_scores = torch.zeros(repeats + 1) + for i, batch in enumerate(progress(loader)): + z_batch = batch[0] + model.ablation[layer] = None + tensor_images = model(z_batch.to(device)) + seg = segmenter.segment_batch(tensor_images, downsample=2) + mask = (seg == classnum).max(1)[0] + downsampled_seg = torch.nn.functional.adaptive_avg_pool2d( + mask.float()[:,None,:,:], feature_shape)[:,0,:,:] + total_scores[0] += downsampled_seg.sum().cpu() + # Now we need to do an intervention for every location + # that had a nonzero downsampled_seg, if any. + interventions_needed = downsampled_seg.nonzero() + location_count = len(interventions_needed) + if location_count == 0: + continue + interventions_needed = interventions_needed.repeat(repeats, 1) + inter_z = batch[0][interventions_needed[:,0]].to(device) + inter_chan = torch.zeros(repeats, location_count, feature_units, + device=device) + for j, u in enumerate(ordering): + inter_chan[j:, :, u] = 1 + inter_chan = inter_chan.view(len(inter_z), feature_units) + inter_loc = interventions_needed[:,1:] + scores = torch.zeros(len(inter_z)) + batch_size = len(batch[0]) + for j in range(0, len(inter_z), batch_size): + ibz = inter_z[j:j+batch_size] + ibl = inter_loc[j:j+batch_size].t() + imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device) + imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1 + ibc = inter_chan[j:j+batch_size] + model.ablation[layer] = ( + imask.float()[:,None,:,:] * ibc[:,:,None,None]) + tensor_images = model(ibz) + seg = segmenter.segment_batch(tensor_images, downsample=2) + mask = (seg == classnum).max(1)[0] + downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d( + mask.float()[:,None,:,:], feature_shape)[:,0,:,:] + scores[j:j+batch_size] = downsampled_iseg[ + (torch.arange(len(ibz)),) + tuple(ibl)] + scores = scores.view(repeats, location_count).sum(1) + total_scores[1:] += scores + return total_scores + +def count_segments(segmenter, loader, model): + total_bincount = 0 + data_size = 0 + progress = default_progress() + for i, batch in enumerate(progress(loader)): + tensor_images = model(z_batch.to(device)) + seg = segmenter.segment_batch(tensor_images, downsample=2) + bc = (seg + index[:, None, None, None] * self.num_classes).view(-1 + ).bincount(minlength=z_batch.shape[0] * self.num_classes) + data_size += seg.shape[0] * seg.shape[2] * seg.shape[3] + total_bincount += batch_label_counts.float().sum(0) + normalized_bincount = total_bincount / data_size + return normalized_bincount + +if __name__ == '__main__': + main() diff --git a/netdissect/fullablate.py b/netdissect/fullablate.py new file mode 100644 index 
0000000000000000000000000000000000000000..f92d2c514c0b92b3f33653c5b53198c9fd09cb80 --- /dev/null +++ b/netdissect/fullablate.py @@ -0,0 +1,235 @@ +import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL +from torchvision import transforms +from torch.utils.data import TensorDataset +from netdissect.progress import default_progress, post_progress, desc_progress +from netdissect.progress import verbose_progress, print_progress +from netdissect.nethook import edit_layers +from netdissect.zdataset import standard_z_sample +from netdissect.autoeval import autoimport_eval +from netdissect.easydict import EasyDict +from netdissect.modelconfig import create_instrumented_model + +help_epilog = '''\ +Example: + +python -m netdissect.evalablate \ + --segmenter "netdissect.GanImageSegmenter(segvocab='lowres', segsizes=[160,288], segdiv='quad')" \ + --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \ + --outdir dissect/dissectdir \ + --classname tree \ + --layer layer4 \ + --size 1000 + +Output layout: +dissectdir/layer5/ablation/mirror-iqr.json +{ class: "mirror", + classnum: 43, + pixel_total: 41342300, + class_pixels: 1234531, + layer: "layer5", + ranking: "mirror-iqr", + ablation_units: [341, 23, 12, 142, 83, ...] + ablation_pixels: [143242, 132344, 429931, ...] +} + +''' + +def main(): + # Training settings + def strpair(arg): + p = tuple(arg.split(':')) + if len(p) == 1: + p = p + p + return p + + parser = argparse.ArgumentParser(description='Ablation eval', + epilog=textwrap.dedent(help_epilog), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument('--model', type=str, default=None, + help='constructor for the model to test') + parser.add_argument('--pthfile', type=str, default=None, + help='filename of .pth file for the model') + parser.add_argument('--outdir', type=str, default='dissect', required=True, + help='directory for dissection output') + parser.add_argument('--layer', type=strpair, + help='space-separated list of layer names to edit' + + ', in the form layername[:reportedname]') + parser.add_argument('--classname', type=str, + help='class name to ablate') + parser.add_argument('--metric', type=str, default='iou', + help='ordering metric for selecting units') + parser.add_argument('--unitcount', type=int, default=30, + help='number of units to ablate') + parser.add_argument('--segmenter', type=str, + help='directory containing segmentation dataset') + parser.add_argument('--netname', type=str, default=None, + help='name for network in generated reports') + parser.add_argument('--batch_size', type=int, default=25, + help='batch size for forward pass') + parser.add_argument('--mixed_units', action='store_true', default=False, + help='true to keep alpha for non-zeroed units') + parser.add_argument('--size', type=int, default=200, + help='number of images to test') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA usage') + parser.add_argument('--quiet', action='store_true', default=False, + help='silences console output') + if len(sys.argv) == 1: + parser.print_usage(sys.stderr) + sys.exit(1) + args = parser.parse_args() + + # Set up console output + verbose_progress(not args.quiet) + + # Speed up pytorch + torch.backends.cudnn.benchmark = True + + # Set up CUDA + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + torch.backends.cudnn.benchmark = True + + # Take defaults for model constructor etc from dissect.json settings. 
+ with open(os.path.join(args.outdir, 'dissect.json')) as f: + dissection = EasyDict(json.load(f)) + if args.model is None: + args.model = dissection.settings.model + if args.pthfile is None: + args.pthfile = dissection.settings.pthfile + if args.segmenter is None: + args.segmenter = dissection.settings.segmenter + if args.layer is None: + args.layer = dissection.settings.layers[0] + args.layers = [args.layer] + + # Also load specific analysis + layername = args.layer[1] + if args.metric == 'iou': + summary = dissection + else: + with open(os.path.join(args.outdir, layername, args.metric, + args.classname, 'summary.json')) as f: + summary = EasyDict(json.load(f)) + + # Instantiate generator + model = create_instrumented_model(args, gen=True, edit=True) + if model is None: + print('No model specified') + sys.exit(1) + + # Instantiate model + device = next(model.parameters()).device + input_shape = model.input_shape + + # 4d input if convolutional, 2d input if first layer is linear. + raw_sample = standard_z_sample(args.size, input_shape[1], seed=3).view( + (args.size,) + input_shape[1:]) + dataset = TensorDataset(raw_sample) + + # Create the segmenter + segmenter = autoimport_eval(args.segmenter) + + # Now do the actual work. + labelnames, catnames = ( + segmenter.get_label_and_category_names(dataset)) + label_category = [catnames.index(c) if c in catnames else 0 + for l, c in labelnames] + labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)} + + segloader = torch.utils.data.DataLoader(dataset, + batch_size=args.batch_size, num_workers=10, + pin_memory=(device.type == 'cuda')) + + # Index the dissection layers by layer name. + + # First, collect a baseline + for l in model.ablation: + model.ablation[l] = None + + # For each sort-order, do an ablation + progress = default_progress() + classname = args.classname + classnum = labelnum_from_name[classname] + + # Get iou ranking from dissect.json + iou_rankname = '%s-%s' % (classname, 'iou') + dissect_layer = {lrec.layer: lrec for lrec in dissection.layers} + iou_ranking = next(r for r in dissect_layer[layername].rankings + if r.name == iou_rankname) + + # Get trained ranking from summary.json + rankname = '%s-%s' % (classname, args.metric) + summary_layer = {lrec.layer: lrec for lrec in summary.layers} + ranking = next(r for r in summary_layer[layername].rankings + if r.name == rankname) + + # Get ordering, first by ranking, then break ties by iou. + ordering = [t[2] for t in sorted([(s1, s2, i) + for i, (s1, s2) in enumerate(zip(ranking.score, iou_ranking.score))])] + values = (-numpy.array(ranking.score))[ordering] + if not args.mixed_units: + values[...] = 1 + + ablationdir = os.path.join(args.outdir, layername, 'fullablation') + measurements = measure_full_ablation(segmenter, segloader, + model, classnum, layername, + ordering[:args.unitcount], values[:args.unitcount]) + measurements = measurements.cpu().numpy().tolist() + os.makedirs(ablationdir, exist_ok=True) + with open(os.path.join(ablationdir, '%s.json'%rankname), 'w') as f: + json.dump(dict( + classname=classname, + classnum=classnum, + baseline=measurements[0], + layer=layername, + metric=args.metric, + ablation_units=ordering, + ablation_values=values.tolist(), + ablation_effects=measurements[1:]), f) + +def measure_full_ablation(segmenter, loader, model, classnum, layer, + ordering, values): + ''' + Quick and easy counting of segmented pixels reduced by ablating units. 
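+
+ total_scores[0] is the baseline count of class pixels with no units
+ ablated; total_scores[k] is the count after ablating the first k units
+ of `ordering`, with per-unit ablation strength taken from `values`
+ (all ones unless --mixed_units is set).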
+ ''' + progress = default_progress() + device = next(model.parameters()).device + feature_units = model.feature_shape[layer][1] + feature_shape = model.feature_shape[layer][2:] + repeats = len(ordering) + total_scores = torch.zeros(repeats + 1) + print(ordering) + print(values.tolist()) + with torch.no_grad(): + for l in model.ablation: + model.ablation[l] = None + for i, [ibz] in enumerate(progress(loader)): + ibz = ibz.cuda() + for num_units in progress(range(len(ordering) + 1)): + ablation = torch.zeros(feature_units, device=device) + ablation[ordering[:num_units]] = torch.tensor( + values[:num_units]).to(ablation.device, ablation.dtype) + model.ablation[layer] = ablation + tensor_images = model(ibz) + seg = segmenter.segment_batch(tensor_images, downsample=2) + mask = (seg == classnum).max(1)[0] + total_scores[num_units] += mask.sum().float().cpu() + return total_scores + +def count_segments(segmenter, loader, model): + total_bincount = 0 + data_size = 0 + progress = default_progress() + for i, batch in enumerate(progress(loader)): + tensor_images = model(z_batch.to(device)) + seg = segmenter.segment_batch(tensor_images, downsample=2) + bc = (seg + index[:, None, None, None] * self.num_classes).view(-1 + ).bincount(minlength=z_batch.shape[0] * self.num_classes) + data_size += seg.shape[0] * seg.shape[2] * seg.shape[3] + total_bincount += batch_label_counts.float().sum(0) + normalized_bincount = total_bincount / data_size + return normalized_bincount + +if __name__ == '__main__': + main() diff --git a/netdissect/modelconfig.py b/netdissect/modelconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ee37a809ea1bcbd803cd7d4e100e1bb93290c9 --- /dev/null +++ b/netdissect/modelconfig.py @@ -0,0 +1,144 @@ +''' +Original from https://github.com/CSAILVision/GANDissect +Modified by Erik Härkönen, 29.11.2019 +''' + +import numbers +import torch +from netdissect.autoeval import autoimport_eval +from netdissect.progress import print_progress +from netdissect.nethook import InstrumentedModel +from netdissect.easydict import EasyDict + +def create_instrumented_model(args, **kwargs): + ''' + Creates an instrumented model out of a namespace of arguments that + correspond to ArgumentParser command-line args: + model: a string to evaluate as a constructor for the model. + pthfile: (optional) filename of .pth file for the model. + layers: a list of layers to instrument, defaulted if not provided. + edit: True to instrument the layers for editing. + gen: True for a generator model. One-pixel input assumed. + imgsize: For non-generator models, (y, x) dimensions for RGB input. + cuda: True to use CUDA. + + The constructed model will be decorated with the following attributes: + input_shape: (usually 4d) tensor shape for single-image input. + output_shape: 4d tensor shape for output. + feature_shape: map of layer names to 4d tensor shape for featuremaps. + retained: map of layernames to tensors, filled after every evaluation. + ablation: if editing, map of layernames to [0..1] alpha values to fill. + replacement: if editing, map of layernames to values to fill. 
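+        meta: dict of any scalar (numeric) entries stored alongside the
+            state_dict in the loaded .pth file, if such a file was given.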
+ + When editing, the feature value x will be replaced by: + `x = (replacement * ablation) + (x * (1 - ablation))` + ''' + + args = EasyDict(vars(args), **kwargs) + + # Construct the network + if args.model is None: + print_progress('No model specified') + return None + if isinstance(args.model, torch.nn.Module): + model = args.model + else: + model = autoimport_eval(args.model) + # Unwrap any DataParallel-wrapped model + if isinstance(model, torch.nn.DataParallel): + model = next(model.children()) + + # Load its state dict + meta = {} + if getattr(args, 'pthfile', None) is not None: + data = torch.load(args.pthfile) + if 'state_dict' in data: + meta = {} + for key in data: + if isinstance(data[key], numbers.Number): + meta[key] = data[key] + data = data['state_dict'] + submodule = getattr(args, 'submodule', None) + if submodule is not None and len(submodule): + remove_prefix = submodule + '.' + data = { k[len(remove_prefix):]: v for k, v in data.items() + if k.startswith(remove_prefix)} + if not len(data): + print_progress('No submodule %s found in %s' % + (submodule, args.pthfile)) + return None + model.load_state_dict(data, strict=not getattr(args, 'unstrict', False)) + + # Decide which layers to instrument. + if getattr(args, 'layer', None) is not None: + args.layers = [args.layer] + if getattr(args, 'layers', None) is None: + # Skip wrappers with only one named model + container = model + prefix = '' + while len(list(container.named_children())) == 1: + name, container = next(container.named_children()) + prefix += name + '.' + # Default to all nontrivial top-level layers except last. + args.layers = [prefix + name + for name, module in container.named_children() + if type(module).__module__ not in [ + # Skip ReLU and other activations. + 'torch.nn.modules.activation', + # Skip pooling layers. + 'torch.nn.modules.pooling'] + ][:-1] + print_progress('Defaulting to layers: %s' % ' '.join(args.layers)) + + # Now wrap the model for instrumentation. + model = InstrumentedModel(model) + model.meta = meta + + # Instrument the layers. + model.retain_layers(args.layers) + model.eval() + if args.cuda: + model.cuda() + + # Annotate input, output, and feature shapes + annotate_model_shapes(model, + gen=getattr(args, 'gen', False), + imgsize=getattr(args, 'imgsize', None), + latent_shape=getattr(args, 'latent_shape', None)) + return model + +def annotate_model_shapes(model, gen=False, imgsize=None, latent_shape=None): + assert (imgsize is not None) or gen + + # Figure the input shape. + if gen: + if latent_shape is None: + # We can guess a generator's input shape by looking at the model. + # Examine first conv in model to determine input feature size. + first_layer = [c for c in model.modules() + if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, + torch.nn.Linear))][0] + # 4d input if convolutional, 2d input if first layer is linear. + if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + input_shape = (1, first_layer.in_channels, 1, 1) + else: + input_shape = (1, first_layer.in_features) + else: + # Specify input shape manually + input_shape = latent_shape + else: + # For a classifier, the input image shape is given as an argument. + input_shape = (1, 3) + tuple(imgsize) + + # Run the model once to observe feature shapes. + device = next(model.parameters()).device + dry_run = torch.zeros(input_shape).to(device) + with torch.no_grad(): + output = model(dry_run) + + # Annotate shapes. 
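+    # For example (illustrative values only), a 256px progressive GAN
+    # generator instrumented at 'layer4' would end up with roughly:
+    #   model.input_shape   == (1, 512, 1, 1)
+    #   model.feature_shape == {'layer4': torch.Size([1, 512, 8, 8])}
+    #   model.output_shape  == torch.Size([1, 3, 256, 256])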
+ model.input_shape = input_shape + model.feature_shape = { layer: feature.shape + for layer, feature in model.retained_features().items() } + model.output_shape = output.shape + return model diff --git a/netdissect/nethook.py b/netdissect/nethook.py new file mode 100644 index 0000000000000000000000000000000000000000..f36e84ee0cae2de2c3be247498408cf66db3ee8f --- /dev/null +++ b/netdissect/nethook.py @@ -0,0 +1,266 @@ +''' +Utilities for instrumenting a torch model. + +InstrumentedModel will wrap a pytorch model and allow hooking +arbitrary layers to monitor or modify their output directly. + +Modified by Erik Härkönen: +- 29.11.2019: Unhooking bugfix +- 25.01.2020: Offset edits, removed old API +''' + +import torch, numpy, types +from collections import OrderedDict + +class InstrumentedModel(torch.nn.Module): + ''' + A wrapper for hooking, probing and intervening in pytorch Modules. + Example usage: + + ``` + model = load_my_model() + with inst as InstrumentedModel(model): + inst.retain_layer(layername) + inst.edit_layer(layername, 0.5, target_features) + inst.edit_layer(layername, offset=offset_tensor) + inst(inputs) + original_features = inst.retained_layer(layername) + ``` + ''' + + def __init__(self, model): + super(InstrumentedModel, self).__init__() + self.model = model + self._retained = OrderedDict() + self._ablation = {} + self._replacement = {} + self._offset = {} + self._hooked_layer = {} + self._old_forward = {} + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + + def forward(self, *inputs, **kwargs): + return self.model(*inputs, **kwargs) + + def retain_layer(self, layername): + ''' + Pass a fully-qualified layer name (E.g., module.submodule.conv3) + to hook that layer and retain its output each time the model is run. + A pair (layername, aka) can be provided, and the aka will be used + as the key for the retained value instead of the layername. + ''' + self.retain_layers([layername]) + + def retain_layers(self, layernames): + ''' + Retains a list of a layers at once. + ''' + self.add_hooks(layernames) + for layername in layernames: + aka = layername + if not isinstance(aka, str): + layername, aka = layername + if aka not in self._retained: + self._retained[aka] = None + + def retained_features(self): + ''' + Returns a dict of all currently retained features. + ''' + return OrderedDict(self._retained) + + def retained_layer(self, aka=None, clear=False): + ''' + Retrieve retained data that was previously hooked by retain_layer. + Call this after the model is run. If clear is set, then the + retained value will return and also cleared. + ''' + if aka is None: + # Default to the first retained layer. + aka = next(self._retained.keys().__iter__()) + result = self._retained[aka] + if clear: + self._retained[aka] = None + return result + + def edit_layer(self, layername, ablation=None, replacement=None, offset=None): + ''' + Pass a fully-qualified layer name (E.g., module.submodule.conv3) + to hook that layer and modify its output each time the model is run. + The output of the layer will be modified to be a convex combination + of the replacement and x interpolated according to the ablation, i.e.: + `output = x * (1 - a) + (r * a)`. + Additionally or independently, an offset can be added to the output. + ''' + if not isinstance(layername, str): + layername, aka = layername + else: + aka = layername + + # The default ablation if a replacement is specified is 1.0. 
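+        # so edit_layer(name, replacement=r) fully replaces the output with r.
+        # Conversely (sketch; 'inst', 'num_channels' and 'units_to_zero' are
+        # placeholders), zeroing selected channels needs only an ablation mask:
+        #   alpha = torch.zeros(num_channels)
+        #   alpha[units_to_zero] = 1.0
+        #   inst.edit_layer('layer4', ablation=alpha)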
+ if ablation is None and replacement is not None: + ablation = 1.0 + self.add_hooks([(layername, aka)]) + if ablation is not None: + self._ablation[aka] = ablation + if replacement is not None: + self._replacement[aka] = replacement + if offset is not None: + self._offset[aka] = offset + # If needed, could add an arbitrary postprocessing lambda here. + + def remove_edits(self, layername=None, remove_offset=True, remove_replacement=True): + ''' + Removes edits at the specified layer, or removes edits at all layers + if no layer name is specified. + ''' + if layername is None: + if remove_replacement: + self._ablation.clear() + self._replacement.clear() + if remove_offset: + self._offset.clear() + return + + if not isinstance(layername, str): + layername, aka = layername + else: + aka = layername + if remove_replacement and aka in self._ablation: + del self._ablation[aka] + if remove_replacement and aka in self._replacement: + del self._replacement[aka] + if remove_offset and aka in self._offset: + del self._offset[aka] + + def add_hooks(self, layernames): + ''' + Sets up a set of layers to be hooked. + + Usually not called directly: use edit_layer or retain_layer instead. + ''' + needed = set() + aka_map = {} + for name in layernames: + aka = name + if not isinstance(aka, str): + name, aka = name + if self._hooked_layer.get(aka, None) != name: + aka_map[name] = aka + needed.add(name) + if not needed: + return + for name, layer in self.model.named_modules(): + if name in aka_map: + needed.remove(name) + aka = aka_map[name] + self._hook_layer(layer, name, aka) + for name in needed: + raise ValueError('Layer %s not found in model' % name) + + def _hook_layer(self, layer, layername, aka): + ''' + Internal method to replace a forward method with a closure that + intercepts the call, and tracks the hook so that it can be reverted. + ''' + if aka in self._hooked_layer: + raise ValueError('Layer %s already hooked' % aka) + if layername in self._old_forward: + raise ValueError('Layer %s already hooked' % layername) + self._hooked_layer[aka] = layername + self._old_forward[layername] = (layer, aka, + layer.__dict__.get('forward', None)) + editor = self + original_forward = layer.forward + def new_forward(self, *inputs, **kwargs): + original_x = original_forward(*inputs, **kwargs) + x = editor._postprocess_forward(original_x, aka) + return x + layer.forward = types.MethodType(new_forward, layer) + + def _unhook_layer(self, aka): + ''' + Internal method to remove a hook, restoring the original forward method. + ''' + if aka not in self._hooked_layer: + return + layername = self._hooked_layer[aka] + layer, check, old_forward = self._old_forward[layername] + assert check == aka + if old_forward is None: + if 'forward' in layer.__dict__: + del layer.__dict__['forward'] + else: + layer.forward = old_forward + del self._old_forward[layername] + del self._hooked_layer[aka] + if aka in self._ablation: + del self._ablation[aka] + if aka in self._replacement: + del self._replacement[aka] + if aka in self._offset: + del self._offset[aka] + if aka in self._retained: + del self._retained[aka] + + def _postprocess_forward(self, x, aka): + ''' + The internal method called by the hooked layers after they are run. + ''' + # Retain output before edits, if desired. 
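+        # (The retained copy is detached and captured before any ablation,
+        # replacement, or offset below is applied.)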
+ if aka in self._retained: + self._retained[aka] = x.detach() + + # Apply replacement edit + a = make_matching_tensor(self._ablation, aka, x) + if a is not None: + x = x * (1 - a) + v = make_matching_tensor(self._replacement, aka, x) + if v is not None: + x += (v * a) + + # Apply offset edit + b = make_matching_tensor(self._offset, aka, x) + if b is not None: + x = x + b + + return x + + def close(self): + ''' + Unhooks all hooked layers in the model. + ''' + for aka in list(self._old_forward.keys()): + self._unhook_layer(aka) + assert len(self._old_forward) == 0 + + +def make_matching_tensor(valuedict, name, data): + ''' + Converts `valuedict[name]` to be a tensor with the same dtype, device, + and dimension count as `data`, and caches the converted tensor. + ''' + v = valuedict.get(name, None) + if v is None: + return None + if not isinstance(v, torch.Tensor): + # Accept non-torch data. + v = torch.from_numpy(numpy.array(v)) + valuedict[name] = v + if not v.device == data.device or not v.dtype == data.dtype: + # Ensure device and type matches. + assert not v.requires_grad, '%s wrong device or type' % (name) + v = v.to(device=data.device, dtype=data.dtype) + valuedict[name] = v + if len(v.shape) < len(data.shape): + # Ensure dimensions are unsqueezed as needed. + assert not v.requires_grad, '%s wrong dimensions' % (name) + v = v.view((1,) + tuple(v.shape) + + (1,) * (len(data.shape) - len(v.shape) - 1)) + valuedict[name] = v + return v diff --git a/netdissect/parallelfolder.py b/netdissect/parallelfolder.py new file mode 100644 index 0000000000000000000000000000000000000000..a741691569a7c85e96d3b3d9be12b40d508f0044 --- /dev/null +++ b/netdissect/parallelfolder.py @@ -0,0 +1,118 @@ +''' +Variants of pytorch's ImageFolder for loading image datasets with more +information, such as parallel feature channels in separate files, +cached files with lists of filenames, etc. +''' + +import os, torch, re +import torch.utils.data as data +from torchvision.datasets.folder import default_loader +from PIL import Image +from collections import OrderedDict +from .progress import default_progress + +def grayscale_loader(path): + with open(path, 'rb') as f: + return Image.open(f).convert('L') + +class ParallelImageFolders(data.Dataset): + """ + A data loader that looks for parallel image filenames, for example + + photo1/park/004234.jpg + photo1/park/004236.jpg + photo1/park/004237.jpg + + photo2/park/004234.png + photo2/park/004236.png + photo2/park/004237.png + """ + def __init__(self, image_roots, + transform=None, + loader=default_loader, + stacker=None, + intersection=False, + verbose=None, + size=None): + self.image_roots = image_roots + self.images = make_parallel_dataset(image_roots, + intersection=intersection, verbose=verbose) + if len(self.images) == 0: + raise RuntimeError("Found 0 images within: %s" % image_roots) + if size is not None: + self.image = self.images[:size] + if transform is not None and not hasattr(transform, '__iter__'): + transform = [transform for _ in image_roots] + self.transforms = transform + self.stacker = stacker + self.loader = loader + + def __getitem__(self, index): + paths = self.images[index] + sources = [self.loader(path) for path in paths] + # Add a common shared state dict to allow random crops/flips to be + # coordinated. 
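+        # (Every source image in the tuple receives the same dict, so a paired
+        # transform could, for instance, record its random crop box in
+        # img.shared_state for the first image and reuse it for the others;
+        # such transforms are assumed rather than provided here.)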
+ shared_state = {} + for s in sources: + s.shared_state = shared_state + if self.transforms is not None: + sources = [transform(source) + for source, transform in zip(sources, self.transforms)] + if self.stacker is not None: + sources = self.stacker(sources) + else: + sources = tuple(sources) + return sources + + def __len__(self): + return len(self.images) + +def is_npy_file(path): + return path.endswith('.npy') or path.endswith('.NPY') + +def is_image_file(path): + return None != re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) + +def walk_image_files(rootdir, verbose=None): + progress = default_progress(verbose) + indexfile = '%s.txt' % rootdir + if os.path.isfile(indexfile): + basedir = os.path.dirname(rootdir) + with open(indexfile) as f: + result = sorted([os.path.join(basedir, line.strip()) + for line in progress(f.readlines(), + desc='Reading %s' % os.path.basename(indexfile))]) + return result + result = [] + for dirname, _, fnames in sorted(progress(os.walk(rootdir), + desc='Walking %s' % os.path.basename(rootdir))): + for fname in sorted(fnames): + if is_image_file(fname) or is_npy_file(fname): + result.append(os.path.join(dirname, fname)) + return result + +def make_parallel_dataset(image_roots, intersection=False, verbose=None): + """ + Returns [(img1, img2), (img1, img2)..] + """ + image_roots = [os.path.expanduser(d) for d in image_roots] + image_sets = OrderedDict() + for j, root in enumerate(image_roots): + for path in walk_image_files(root, verbose=verbose): + key = os.path.splitext(os.path.relpath(path, root))[0] + if key not in image_sets: + image_sets[key] = [] + if not intersection and len(image_sets[key]) != j: + raise RuntimeError( + 'Images not parallel: %s missing from one dir' % (key)) + image_sets[key].append(path) + tuples = [] + for key, value in image_sets.items(): + if len(value) != len(image_roots): + if intersection: + continue + else: + raise RuntimeError( + 'Images not parallel: %s missing from one dir' % (key)) + tuples.append(tuple(value)) + return tuples diff --git a/netdissect/pidfile.py b/netdissect/pidfile.py new file mode 100644 index 0000000000000000000000000000000000000000..96a66814326bad444606ad829307fe225f4135e1 --- /dev/null +++ b/netdissect/pidfile.py @@ -0,0 +1,81 @@ +''' +Utility for simple distribution of work on multiple processes, by +making sure only one process is working on a job at once. +''' + +import os, errno, socket, atexit, time, sys + +def exit_if_job_done(directory): + if pidfile_taken(os.path.join(directory, 'lockfile.pid'), verbose=True): + sys.exit(0) + if os.path.isfile(os.path.join(directory, 'done.txt')): + with open(os.path.join(directory, 'done.txt')) as f: + msg = f.read() + print(msg) + sys.exit(0) + +def mark_job_done(directory): + with open(os.path.join(directory, 'done.txt'), 'w') as f: + f.write('Done by %d@%s %s at %s' % + (os.getpid(), socket.gethostname(), + os.getenv('STY', ''), + time.strftime('%c'))) + +def pidfile_taken(path, verbose=False): + ''' + Usage. To grab an exclusive lock for the remaining duration of the + current process (and exit if another process already has the lock), + do this: + + if pidfile_taken('job_423/lockfile.pid', verbose=True): + sys.exit(0) + + To do a batch of jobs, just run a script that does them all on + each available machine, sharing a network filesystem. When each + job grabs a lock, then this will automatically distribute the + jobs so that each one is done just once on one machine. + ''' + + # Try to create the file exclusively and write my pid into it. 
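+    # os.O_CREAT | os.O_EXCL makes the create-and-claim step atomic on most
+    # filesystems, so at most one process gets the lock; the rest fall into
+    # the EEXIST branch below.  Typical batch usage of the helpers above
+    # (sketch): call exit_if_job_done('dissect/job_423') before the work and
+    # mark_job_done('dissect/job_423') after it.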
+ try: + os.makedirs(os.path.dirname(path), exist_ok=True) + fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR) + except OSError as e: + if e.errno == errno.EEXIST: + # If we cannot because there was a race, yield the conflicter. + conflicter = 'race' + try: + with open(path, 'r') as lockfile: + conflicter = lockfile.read().strip() or 'empty' + except: + pass + if verbose: + print('%s held by %s' % (path, conflicter)) + return conflicter + else: + # Other problems get an exception. + raise + # Register to delete this file on exit. + lockfile = os.fdopen(fd, 'r+') + atexit.register(delete_pidfile, lockfile, path) + # Write my pid into the open file. + lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(), + os.getenv('STY', ''))) + lockfile.flush() + os.fsync(lockfile) + # Return 'None' to say there was not a conflict. + return None + +def delete_pidfile(lockfile, path): + ''' + Runs at exit after pidfile_taken succeeds. + ''' + if lockfile is not None: + try: + lockfile.close() + except: + pass + try: + os.unlink(path) + except: + pass diff --git a/netdissect/plotutil.py b/netdissect/plotutil.py new file mode 100644 index 0000000000000000000000000000000000000000..187bcb9d5615c8ec51a43148b011c06b8ed6aff7 --- /dev/null +++ b/netdissect/plotutil.py @@ -0,0 +1,61 @@ +import matplotlib.pyplot as plt +import numpy + +def plot_tensor_images(data, **kwargs): + data = ((data + 1) / 2 * 255).permute(0, 2, 3, 1).byte().cpu().numpy() + width = int(numpy.ceil(numpy.sqrt(data.shape[0]))) + height = int(numpy.ceil(data.shape[0] / float(width))) + kwargs = dict(kwargs) + margin = 0.01 + if 'figsize' not in kwargs: + # Size figure to one display pixel per data pixel + dpi = plt.rcParams['figure.dpi'] + kwargs['figsize'] = ( + (1 + margin) * (width * data.shape[2] / dpi), + (1 + margin) * (height * data.shape[1] / dpi)) + f, axarr = plt.subplots(height, width, **kwargs) + if len(numpy.shape(axarr)) == 0: + axarr = numpy.array([[axarr]]) + if len(numpy.shape(axarr)) == 1: + axarr = axarr[None,:] + for i, im in enumerate(data): + ax = axarr[i // width, i % width] + ax.imshow(data[i]) + ax.axis('off') + for i in range(i, width * height): + ax = axarr[i // width, i % width] + ax.axis('off') + plt.subplots_adjust(wspace=margin, hspace=margin, + left=0, right=1, bottom=0, top=1) + plt.show() + +def plot_max_heatmap(data, shape=None, **kwargs): + if shape is None: + shape = data.shape[2:] + data = data.max(1)[0].cpu().numpy() + vmin = data.min() + vmax = data.max() + width = int(numpy.ceil(numpy.sqrt(data.shape[0]))) + height = int(numpy.ceil(data.shape[0] / float(width))) + kwargs = dict(kwargs) + margin = 0.01 + if 'figsize' not in kwargs: + # Size figure to one display pixel per data pixel + dpi = plt.rcParams['figure.dpi'] + kwargs['figsize'] = ( + width * shape[1] / dpi, height * shape[0] / dpi) + f, axarr = plt.subplots(height, width, **kwargs) + if len(numpy.shape(axarr)) == 0: + axarr = numpy.array([[axarr]]) + if len(numpy.shape(axarr)) == 1: + axarr = axarr[None,:] + for i, im in enumerate(data): + ax = axarr[i // width, i % width] + img = ax.imshow(data[i], vmin=vmin, vmax=vmax, cmap='hot') + ax.axis('off') + for i in range(i, width * height): + ax = axarr[i // width, i % width] + ax.axis('off') + plt.subplots_adjust(wspace=margin, hspace=margin, + left=0, right=1, bottom=0, top=1) + plt.show() diff --git a/netdissect/proggan.py b/netdissect/proggan.py new file mode 100644 index 0000000000000000000000000000000000000000..e37ae15f373ef6ad14279bb581042434c5563539 --- /dev/null +++ 
b/netdissect/proggan.py @@ -0,0 +1,299 @@ +import torch, numpy, itertools +import torch.nn as nn +from collections import OrderedDict + + +def print_network(net, verbose=False): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('Total number of parameters: {:3.3f} M'.format(num_params / 1e6)) + + +def from_pth_file(filename): + ''' + Instantiate from a pth file. + ''' + state_dict = torch.load(filename) + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + # Convert old version of parameter names + if 'features.0.conv.weight' in state_dict: + state_dict = state_dict_from_old_pt_dict(state_dict) + sizes = sizes_from_state_dict(state_dict) + result = ProgressiveGenerator(sizes=sizes) + result.load_state_dict(state_dict) + return result + +############################################################################### +# Modules +############################################################################### + +class ProgressiveGenerator(nn.Sequential): + def __init__(self, resolution=None, sizes=None, modify_sequence=None, + output_tanh=False): + ''' + A pytorch progessive GAN generator that can be converted directly + from either a tensorflow model or a theano model. It consists of + a sequence of convolutional layers, organized in pairs, with an + upsampling and reduction of channels at every other layer; and + then finally followed by an output layer that reduces it to an + RGB [-1..1] image. + + The network can be given more layers to increase the output + resolution. The sizes argument indicates the fieature depth at + each upsampling, starting with the input z: [input-dim, 4x4-depth, + 8x8-depth, 16x16-depth...]. The output dimension is 2 * 2**len(sizes) + + Some default architectures can be selected by supplying the + resolution argument instead. + + The optional modify_sequence function can be used to transform the + sequence of layers before the network is constructed. + + If output_tanh is set to True, the network applies a tanh to clamp + the output to [-1,1] before output; otherwise the output is unclamped. + ''' + assert (resolution is None) != (sizes is None) + if sizes is None: + sizes = { + 8: [512, 512, 512], + 16: [512, 512, 512, 512], + 32: [512, 512, 512, 512, 256], + 64: [512, 512, 512, 512, 256, 128], + 128: [512, 512, 512, 512, 256, 128, 64], + 256: [512, 512, 512, 512, 256, 128, 64, 32], + 1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16] + }[resolution] + # Follow the schedule of upsampling given by sizes. + # layers are called: layer1, layer2, etc; then output_128x128 + sequence = [] + def add_d(layer, name=None): + if name is None: + name = 'layer%d' % (len(sequence) + 1) + sequence.append((name, layer)) + add_d(NormConvBlock(sizes[0], sizes[1], kernel_size=4, padding=3)) + add_d(NormConvBlock(sizes[1], sizes[1], kernel_size=3, padding=1)) + for i, (si, so) in enumerate(zip(sizes[1:-1], sizes[2:])): + add_d(NormUpscaleConvBlock(si, so, kernel_size=3, padding=1)) + add_d(NormConvBlock(so, so, kernel_size=3, padding=1)) + # Create an output layer. During training, the progressive GAN + # learns several such output layers for various resolutions; we + # just include the last (highest resolution) one. 
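+        # At this point len(sequence) == 2 + 2 * (len(sizes) - 2), so
+        # len(sequence) // 2 - 1 is the number of upsampling stages and dim
+        # works out to 4 * 2**(len(sizes) - 2), i.e. 2**len(sizes).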
+ dim = 4 * (2 ** (len(sequence) // 2 - 1)) + add_d(OutputConvBlock(sizes[-1], tanh=output_tanh), + name='output_%dx%d' % (dim, dim)) + # Allow the sequence to be modified + if modify_sequence is not None: + sequence = modify_sequence(sequence) + super().__init__(OrderedDict(sequence)) + + def forward(self, x): + # Convert vector input to 1x1 featuremap. + x = x.view(x.shape[0], x.shape[1], 1, 1) + return super().forward(x) + +class PixelNormLayer(nn.Module): + def __init__(self): + super(PixelNormLayer, self).__init__() + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + +class DoubleResolutionLayer(nn.Module): + def forward(self, x): + x = nn.functional.interpolate(x, scale_factor=2, mode='nearest') + return x + +class WScaleLayer(nn.Module): + def __init__(self, size, fan_in, gain=numpy.sqrt(2)): + super(WScaleLayer, self).__init__() + self.scale = gain / numpy.sqrt(fan_in) # No longer a parameter + self.b = nn.Parameter(torch.randn(size)) + self.size = size + + def forward(self, x): + x_size = x.size() + x = x * self.scale + self.b.view(1, -1, 1, 1).expand( + x_size[0], self.size, x_size[2], x_size[3]) + return x + +class NormConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding): + super(NormConvBlock, self).__init__() + self.norm = PixelNormLayer() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, 1, padding, bias=False) + self.wscale = WScaleLayer(out_channels, in_channels, + gain=numpy.sqrt(2) / kernel_size) + self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2) + + def forward(self, x): + x = self.norm(x) + x = self.conv(x) + x = self.relu(self.wscale(x)) + return x + +class NormUpscaleConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding): + super(NormUpscaleConvBlock, self).__init__() + self.norm = PixelNormLayer() + self.up = DoubleResolutionLayer() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, 1, padding, bias=False) + self.wscale = WScaleLayer(out_channels, in_channels, + gain=numpy.sqrt(2) / kernel_size) + self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2) + + def forward(self, x): + x = self.norm(x) + x = self.up(x) + x = self.conv(x) + x = self.relu(self.wscale(x)) + return x + +class OutputConvBlock(nn.Module): + def __init__(self, in_channels, tanh=False): + super().__init__() + self.norm = PixelNormLayer() + self.conv = nn.Conv2d( + in_channels, 3, kernel_size=1, padding=0, bias=False) + self.wscale = WScaleLayer(3, in_channels, gain=1) + self.clamp = nn.Hardtanh() if tanh else (lambda x: x) + + def forward(self, x): + x = self.norm(x) + x = self.conv(x) + x = self.wscale(x) + x = self.clamp(x) + return x + +############################################################################### +# Conversion +############################################################################### + +def from_tf_parameters(parameters): + ''' + Instantiate from tensorflow variables. + ''' + state_dict = state_dict_from_tf_parameters(parameters) + sizes = sizes_from_state_dict(state_dict) + result = ProgressiveGenerator(sizes=sizes) + result.load_state_dict(state_dict) + return result + +def from_old_pt_dict(parameters): + ''' + Instantiate from old pytorch state dict. 
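+
+    In the old naming scheme, convolution blocks are stored as
+    'features.N.conv.weight' / 'features.N.wscale.b' and the output block as
+    'output.conv.weight' / 'output.wscale.b'; these are renamed here to the
+    'layerN' and 'output_RESxRES' names used by ProgressiveGenerator.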
+ ''' + state_dict = state_dict_from_old_pt_dict(parameters) + sizes = sizes_from_state_dict(state_dict) + result = ProgressiveGenerator(sizes=sizes) + result.load_state_dict(state_dict) + return result + +def sizes_from_state_dict(params): + ''' + In a progressive GAN, the number of channels can change after each + upsampling. This function reads the state dict to figure the + number of upsamplings and the channel depth of each filter. + ''' + sizes = [] + for i in itertools.count(): + pt_layername = 'layer%d' % (i + 1) + try: + weight = params['%s.conv.weight' % pt_layername] + except KeyError: + break + if i == 0: + sizes.append(weight.shape[1]) + if i % 2 == 0: + sizes.append(weight.shape[0]) + return sizes + +def state_dict_from_tf_parameters(parameters): + ''' + Conversion from tensorflow parameters + ''' + def torch_from_tf(data): + return torch.from_numpy(data.eval()) + + params = dict(parameters) + result = {} + sizes = [] + for i in itertools.count(): + resolution = 4 * (2 ** (i // 2)) + # Translate parameter names. For example: + # 4x4/Dense/weight -> layer1.conv.weight + # 32x32/Conv0_up/weight -> layer7.conv.weight + # 32x32/Conv1/weight -> layer8.conv.weight + tf_layername = '%dx%d/%s' % (resolution, resolution, + 'Dense' if i == 0 else 'Conv' if i == 1 else + 'Conv0_up' if i % 2 == 0 else 'Conv1') + pt_layername = 'layer%d' % (i + 1) + # Stop looping when we run out of parameters. + try: + weight = torch_from_tf(params['%s/weight' % tf_layername]) + except KeyError: + break + # Transpose convolution weights into pytorch format. + if i == 0: + # Convert dense layer to 4x4 convolution + weight = weight.view(weight.shape[0], weight.shape[1] // 16, + 4, 4).permute(1, 0, 2, 3).flip(2, 3) + sizes.append(weight.shape[0]) + elif i % 2 == 0: + # Convert inverse convolution to convolution + weight = weight.permute(2, 3, 0, 1).flip(2, 3) + else: + # Ordinary Conv2d conversion. + weight = weight.permute(3, 2, 0, 1) + sizes.append(weight.shape[1]) + result['%s.conv.weight' % (pt_layername)] = weight + # Copy bias vector. + bias = torch_from_tf(params['%s/bias' % tf_layername]) + result['%s.wscale.b' % (pt_layername)] = bias + # Copy just finest-grained ToRGB output layers. For example: + # ToRGB_lod0/weight -> output.conv.weight + i -= 1 + resolution = 4 * (2 ** (i // 2)) + tf_layername = 'ToRGB_lod0' + pt_layername = 'output_%dx%d' % (resolution, resolution) + result['%s.conv.weight' % pt_layername] = torch_from_tf( + params['%s/weight' % tf_layername]).permute(3, 2, 0, 1) + result['%s.wscale.b' % pt_layername] = torch_from_tf( + params['%s/bias' % tf_layername]) + # Return parameters + return result + +def state_dict_from_old_pt_dict(params): + ''' + Conversion from the old pytorch model layer names. + ''' + result = {} + sizes = [] + for i in itertools.count(): + old_layername = 'features.%d' % i + pt_layername = 'layer%d' % (i + 1) + try: + weight = params['%s.conv.weight' % (old_layername)] + except KeyError: + break + if i == 0: + sizes.append(weight.shape[0]) + if i % 2 == 0: + sizes.append(weight.shape[1]) + result['%s.conv.weight' % (pt_layername)] = weight + result['%s.wscale.b' % (pt_layername)] = params[ + '%s.wscale.b' % (old_layername)] + # Copy the output layers. + i -= 1 + resolution = 4 * (2 ** (i // 2)) + pt_layername = 'output_%dx%d' % (resolution, resolution) + result['%s.conv.weight' % pt_layername] = params['output.conv.weight'] + result['%s.wscale.b' % pt_layername] = params['output.wscale.b'] + # Return parameters and also network architecture sizes. 
+ return result + diff --git a/netdissect/progress.py b/netdissect/progress.py new file mode 100644 index 0000000000000000000000000000000000000000..702b24cf6668e6caad38d3c315eb658b6af4d230 --- /dev/null +++ b/netdissect/progress.py @@ -0,0 +1,98 @@ +''' +Utilities for showing progress bars, controlling default verbosity, etc. +''' + +# If the tqdm package is not available, then do not show progress bars; +# just connect print_progress to print. +try: + from tqdm import tqdm, tqdm_notebook +except: + tqdm = None + +default_verbosity = False + +def verbose_progress(verbose): + ''' + Sets default verbosity level. Set to True to see progress bars. + ''' + global default_verbosity + default_verbosity = verbose + +def tqdm_terminal(it, *args, **kwargs): + ''' + Some settings for tqdm that make it run better in resizable terminals. + ''' + return tqdm(it, *args, dynamic_ncols=True, ascii=True, + leave=(not nested_tqdm()), **kwargs) + +def in_notebook(): + ''' + True if running inside a Jupyter notebook. + ''' + # From https://stackoverflow.com/a/39662359/265298 + try: + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False # Probably standard Python interpreter + +def nested_tqdm(): + ''' + True if there is an active tqdm progress loop on the stack. + ''' + return hasattr(tqdm, '_instances') and len(tqdm._instances) > 0 + +def post_progress(**kwargs): + ''' + When within a progress loop, post_progress(k=str) will display + the given k=str status on the right-hand-side of the progress + status bar. If not within a visible progress bar, does nothing. + ''' + if nested_tqdm(): + innermost = max(tqdm._instances, key=lambda x: x.pos) + innermost.set_postfix(**kwargs) + +def desc_progress(desc): + ''' + When within a progress loop, desc_progress(str) changes the + left-hand-side description of the loop toe the given description. + ''' + if nested_tqdm(): + innermost = max(tqdm._instances, key=lambda x: x.pos) + innermost.set_description(desc) + +def print_progress(*args): + ''' + When within a progress loop, post_progress(k=str) will display + the given k=str status on the right-hand-side of the progress + status bar. If not within a visible progress bar, does nothing. + ''' + if default_verbosity: + printfn = print if tqdm is None else tqdm.write + printfn(' '.join(str(s) for s in args)) + +def default_progress(verbose=None, iftop=False): + ''' + Returns a progress function that can wrap iterators to print + progress messages, if verbose is True. + + If verbose is False or if iftop is True and there is already + a top-level tqdm loop being reported, then a quiet non-printing + identity function is returned. + + verbose can also be set to a spefific progress function rather + than True, and that function will be used. 
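+
+    Example (sketch; 'loader' stands for any iterable):
+        progress = default_progress()
+        for batch in progress(loader, desc='Evaluating'):
+            ...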
+ ''' + global default_verbosity + if verbose is None: + verbose = default_verbosity + if not verbose or (iftop and nested_tqdm()) or tqdm is None: + return lambda x, *args, **kw: x + if verbose == True: + return tqdm_notebook if in_notebook() else tqdm_terminal + return verbose diff --git a/netdissect/runningstats.py b/netdissect/runningstats.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4093e0318edeecf8aebc34771adbde5043e2d4 --- /dev/null +++ b/netdissect/runningstats.py @@ -0,0 +1,773 @@ +''' +Running statistics on the GPU using pytorch. + +RunningTopK maintains top-k statistics for a set of channels in parallel. +RunningQuantile maintains (sampled) quantile statistics for a set of channels. +''' + +import torch, math, numpy +from collections import defaultdict + +class RunningTopK: + ''' + A class to keep a running tally of the the top k values (and indexes) + of any number of torch feature components. Will work on the GPU if + the data is on the GPU. + + This version flattens all arrays to avoid crashes. + ''' + def __init__(self, k=100, state=None): + if state is not None: + self.set_state_dict(state) + return + self.k = k + self.count = 0 + # This version flattens all data internally to 2-d tensors, + # to avoid crashes with the current pytorch topk implementation. + # The data is puffed back out to arbitrary tensor shapes on ouput. + self.data_shape = None + self.top_data = None + self.top_index = None + self.next = 0 + self.linear_index = 0 + self.perm = None + + def add(self, data): + ''' + Adds a batch of data to be considered for the running top k. + The zeroth dimension enumerates the observations. All other + dimensions enumerate different features. + ''' + if self.top_data is None: + # Allocation: allocate a buffer of size 5*k, at least 10, for each. + self.data_shape = data.shape[1:] + feature_size = int(numpy.prod(self.data_shape)) + self.top_data = torch.zeros( + feature_size, max(10, self.k * 5), out=data.new()) + self.top_index = self.top_data.clone().long() + self.linear_index = 0 if len(data.shape) == 1 else torch.arange( + feature_size, out=self.top_index.new()).mul_( + self.top_data.shape[-1])[:,None] + size = data.shape[0] + sk = min(size, self.k) + if self.top_data.shape[-1] < self.next + sk: + # Compression: if full, keep topk only. + self.top_data[:,:self.k], self.top_index[:,:self.k] = ( + self.result(sorted=False, flat=True)) + self.next = self.k + free = self.top_data.shape[-1] - self.next + # Pick: copy the top sk of the next batch into the buffer. + # Currently strided topk is slow. So we clone after transpose. + # TODO: remove the clone() if it becomes faster. + cdata = data.contiguous().view(size, -1).t().clone() + td, ti = cdata.topk(sk, sorted=False) + self.top_data[:,self.next:self.next+sk] = td + self.top_index[:,self.next:self.next+sk] = (ti + self.count) + self.next += sk + self.count += size + + def result(self, sorted=True, flat=False): + ''' + Returns top k data items and indexes in each dimension, + with channels in the first dimension and k in the last dimension. + ''' + k = min(self.k, self.next) + # bti are top indexes relative to buffer array. + td, bti = self.top_data[:,:self.next].topk(k, sorted=sorted) + # we want to report top indexes globally, which is ti. 
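+        # linear_index holds each channel's row offset into the flattened
+        # buffer, so (bti + linear_index) turns per-channel buffer positions
+        # into flat positions, and top_index maps those back to the global
+        # observation indices recorded at add() time.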
+ ti = self.top_index.view(-1)[ + (bti + self.linear_index).view(-1) + ].view(*bti.shape) + if flat: + return td, ti + else: + return (td.view(*(self.data_shape + (-1,))), + ti.view(*(self.data_shape + (-1,)))) + + def to_(self, device): + self.top_data = self.top_data.to(device) + self.top_index = self.top_index.to(device) + if isinstance(self.linear_index, torch.Tensor): + self.linear_index = self.linear_index.to(device) + + def state_dict(self): + return dict( + constructor=self.__module__ + '.' + + self.__class__.__name__ + '()', + k=self.k, + count=self.count, + data_shape=tuple(self.data_shape), + top_data=self.top_data.cpu().numpy(), + top_index=self.top_index.cpu().numpy(), + next=self.next, + linear_index=(self.linear_index.cpu().numpy() + if isinstance(self.linear_index, torch.Tensor) + else self.linear_index), + perm=self.perm) + + def set_state_dict(self, dic): + self.k = dic['k'].item() + self.count = dic['count'].item() + self.data_shape = tuple(dic['data_shape']) + self.top_data = torch.from_numpy(dic['top_data']) + self.top_index = torch.from_numpy(dic['top_index']) + self.next = dic['next'].item() + self.linear_index = (torch.from_numpy(dic['linear_index']) + if len(dic['linear_index'].shape) > 0 + else dic['linear_index'].item()) + +class RunningQuantile: + """ + Streaming randomized quantile computation for torch. + + Add any amount of data repeatedly via add(data). At any time, + quantile estimates (or old-style percentiles) can be read out using + quantiles(q) or percentiles(p). + + Accuracy scales according to resolution: the default is to + set resolution to be accurate to better than 0.1%, + while limiting storage to about 50,000 samples. + + Good for computing quantiles of huge data without using much memory. + Works well on arbitrary data with probability near 1. + + Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty + from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf + """ + + def __init__(self, resolution=6 * 1024, buffersize=None, seed=None, + state=None): + if state is not None: + self.set_state_dict(state) + return + self.depth = None + self.dtype = None + self.device = None + self.resolution = resolution + # Default buffersize: 128 samples (and smaller than resolution). + if buffersize is None: + buffersize = min(128, (resolution + 7) // 8) + self.buffersize = buffersize + self.samplerate = 1.0 + self.data = None + self.firstfree = [0] + self.randbits = torch.ByteTensor(resolution) + self.currentbit = len(self.randbits) - 1 + self.extremes = None + self.size = 0 + + def _lazy_init(self, incoming): + self.depth = incoming.shape[1] + self.dtype = incoming.dtype + self.device = incoming.device + self.data = [torch.zeros(self.depth, self.resolution, + dtype=self.dtype, device=self.device)] + self.extremes = torch.zeros(self.depth, 2, + dtype=self.dtype, device=self.device) + self.extremes[:,0] = float('inf') + self.extremes[:,-1] = -float('inf') + + def to_(self, device): + """Switches internal storage to specified device.""" + if device != self.device: + old_data = self.data + old_extremes = self.extremes + self.data = [d.to(device) for d in self.data] + self.extremes = self.extremes.to(device) + self.device = self.extremes.device + del old_data + del old_extremes + + def add(self, incoming): + if self.depth is None: + self._lazy_init(incoming) + assert len(incoming.shape) == 2 + assert incoming.shape[1] == self.depth, (incoming.shape[1], self.depth) + self.size += incoming.shape[0] + # Convert to a flat torch array. 
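+        # While samplerate is 1.0, every incoming row is folded in directly;
+        # once the buffers have been halved enough times (see _expand),
+        # samplerate drops below 1.0 and incoming rows are subsampled in
+        # chunks instead.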
+ if self.samplerate >= 1.0: + self._add_every(incoming) + return + # If we are sampling, then subsample a large chunk at a time. + self._scan_extremes(incoming) + chunksize = int(math.ceil(self.buffersize / self.samplerate)) + for index in range(0, len(incoming), chunksize): + batch = incoming[index:index+chunksize] + sample = sample_portion(batch, self.samplerate) + if len(sample): + self._add_every(sample) + + def _add_every(self, incoming): + supplied = len(incoming) + index = 0 + while index < supplied: + ff = self.firstfree[0] + available = self.data[0].shape[1] - ff + if available == 0: + if not self._shift(): + # If we shifted by subsampling, then subsample. + incoming = incoming[index:] + if self.samplerate >= 0.5: + # First time sampling - the data source is very large. + self._scan_extremes(incoming) + incoming = sample_portion(incoming, self.samplerate) + index = 0 + supplied = len(incoming) + ff = self.firstfree[0] + available = self.data[0].shape[1] - ff + copycount = min(available, supplied - index) + self.data[0][:,ff:ff + copycount] = torch.t( + incoming[index:index + copycount,:]) + self.firstfree[0] += copycount + index += copycount + + def _shift(self): + index = 0 + # If remaining space at the current layer is less than half prev + # buffer size (rounding up), then we need to shift it up to ensure + # enough space for future shifting. + while self.data[index].shape[1] - self.firstfree[index] < ( + -(-self.data[index-1].shape[1] // 2) if index else 1): + if index + 1 >= len(self.data): + return self._expand() + data = self.data[index][:,0:self.firstfree[index]] + data = data.sort()[0] + if index == 0 and self.samplerate >= 1.0: + self._update_extremes(data[:,0], data[:,-1]) + offset = self._randbit() + position = self.firstfree[index + 1] + subset = data[:,offset::2] + self.data[index + 1][:,position:position + subset.shape[1]] = subset + self.firstfree[index] = 0 + self.firstfree[index + 1] += subset.shape[1] + index += 1 + return True + + def _scan_extremes(self, incoming): + # When sampling, we need to scan every item still to get extremes + self._update_extremes( + torch.min(incoming, dim=0)[0], + torch.max(incoming, dim=0)[0]) + + def _update_extremes(self, minr, maxr): + self.extremes[:,0] = torch.min( + torch.stack([self.extremes[:,0], minr]), dim=0)[0] + self.extremes[:,-1] = torch.max( + torch.stack([self.extremes[:,-1], maxr]), dim=0)[0] + + def _randbit(self): + self.currentbit += 1 + if self.currentbit >= len(self.randbits): + self.randbits.random_(to=2) + self.currentbit = 0 + return self.randbits[self.currentbit] + + def state_dict(self): + return dict( + constructor=self.__module__ + '.' 
+ + self.__class__.__name__ + '()', + resolution=self.resolution, + depth=self.depth, + buffersize=self.buffersize, + samplerate=self.samplerate, + data=[d.cpu().numpy()[:,:f].T + for d, f in zip(self.data, self.firstfree)], + sizes=[d.shape[1] for d in self.data], + extremes=self.extremes.cpu().numpy(), + size=self.size) + + def set_state_dict(self, dic): + self.resolution = int(dic['resolution']) + self.randbits = torch.ByteTensor(self.resolution) + self.currentbit = len(self.randbits) - 1 + self.depth = int(dic['depth']) + self.buffersize = int(dic['buffersize']) + self.samplerate = float(dic['samplerate']) + firstfree = [] + buffers = [] + for d, s in zip(dic['data'], dic['sizes']): + firstfree.append(d.shape[0]) + buf = numpy.zeros((d.shape[1], s), dtype=d.dtype) + buf[:,:d.shape[0]] = d.T + buffers.append(torch.from_numpy(buf)) + self.firstfree = firstfree + self.data = buffers + self.extremes = torch.from_numpy((dic['extremes'])) + self.size = int(dic['size']) + self.dtype = self.extremes.dtype + self.device = self.extremes.device + + def minmax(self): + if self.firstfree[0]: + self._scan_extremes(self.data[0][:,:self.firstfree[0]].t()) + return self.extremes.clone() + + def median(self): + return self.quantiles([0.5])[:,0] + + def mean(self): + return self.integrate(lambda x: x) / self.size + + def variance(self): + mean = self.mean()[:,None] + return self.integrate(lambda x: (x - mean).pow(2)) / (self.size - 1) + + def stdev(self): + return self.variance().sqrt() + + def _expand(self): + cap = self._next_capacity() + if cap > 0: + # First, make a new layer of the proper capacity. + self.data.insert(0, torch.zeros(self.depth, cap, + dtype=self.dtype, device=self.device)) + self.firstfree.insert(0, 0) + else: + # Unless we're so big we are just subsampling. + assert self.firstfree[0] == 0 + self.samplerate *= 0.5 + for index in range(1, len(self.data)): + # Scan for existing data that needs to be moved down a level. + amount = self.firstfree[index] + if amount == 0: + continue + position = self.firstfree[index-1] + # Move data down if it would leave enough empty space there + # This is the key invariant: enough empty space to fit half + # of the previous level's buffer size (rounding up) + if self.data[index-1].shape[1] - (amount + position) >= ( + -(-self.data[index-2].shape[1] // 2) if (index-1) else 1): + self.data[index-1][:,position:position + amount] = ( + self.data[index][:,:amount]) + self.firstfree[index-1] += amount + self.firstfree[index] = 0 + else: + # Scrunch the data if it would not. + data = self.data[index][:,:amount] + data = data.sort()[0] + if index == 1: + self._update_extremes(data[:,0], data[:,-1]) + offset = self._randbit() + scrunched = data[:,offset::2] + self.data[index][:,:scrunched.shape[1]] = scrunched + self.firstfree[index] = scrunched.shape[1] + return cap > 0 + + def _next_capacity(self): + cap = int(math.ceil(self.resolution * (0.67 ** len(self.data)))) + if cap < 2: + return 0 + # Round up to the nearest multiple of 8 for better GPU alignment. 
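+        # (-8 * (-cap // 8) is ceiling division by 8, times 8: e.g. 23 -> 24.)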
+ cap = -8 * (-cap // 8) + return max(self.buffersize, cap) + + def _weighted_summary(self, sort=True): + if self.firstfree[0]: + self._scan_extremes(self.data[0][:,:self.firstfree[0]].t()) + size = sum(self.firstfree) + 2 + weights = torch.FloatTensor(size) # Floating point + summary = torch.zeros(self.depth, size, + dtype=self.dtype, device=self.device) + weights[0:2] = 0 + summary[:,0:2] = self.extremes + index = 2 + for level, ff in enumerate(self.firstfree): + if ff == 0: + continue + summary[:,index:index + ff] = self.data[level][:,:ff] + weights[index:index + ff] = 2.0 ** level + index += ff + assert index == summary.shape[1] + if sort: + summary, order = torch.sort(summary, dim=-1) + weights = weights[order.view(-1).cpu()].view(order.shape) + return (summary, weights) + + def quantiles(self, quantiles, old_style=False): + if self.size == 0: + return torch.full((self.depth, len(quantiles)), torch.nan) + summary, weights = self._weighted_summary() + cumweights = torch.cumsum(weights, dim=-1) - weights / 2 + if old_style: + # To be convenient with torch.percentile + cumweights -= cumweights[:,0:1].clone() + cumweights /= cumweights[:,-1:].clone() + else: + cumweights /= torch.sum(weights, dim=-1, keepdim=True) + result = torch.zeros(self.depth, len(quantiles), + dtype=self.dtype, device=self.device) + # numpy is needed for interpolation + if not hasattr(quantiles, 'cpu'): + quantiles = torch.Tensor(quantiles) + nq = quantiles.cpu().numpy() + ncw = cumweights.cpu().numpy() + nsm = summary.cpu().numpy() + for d in range(self.depth): + result[d] = torch.tensor(numpy.interp(nq, ncw[d], nsm[d]), + dtype=self.dtype, device=self.device) + return result + + def integrate(self, fun): + result = None + for level, ff in enumerate(self.firstfree): + if ff == 0: + continue + term = torch.sum( + fun(self.data[level][:,:ff]) * (2.0 ** level), + dim=-1) + if result is None: + result = term + else: + result += term + if result is not None: + result /= self.samplerate + return result + + def percentiles(self, percentiles): + return self.quantiles(percentiles, old_style=True) + + def readout(self, count=1001, old_style=True): + return self.quantiles( + torch.linspace(0.0, 1.0, count), old_style=old_style) + + def normalize(self, data): + ''' + Given input data as taken from the training distirbution, + normalizes every channel to reflect quantile values, + uniformly distributed, within [0, 1]. + ''' + assert self.size > 0 + assert data.shape[0] == self.depth + summary, weights = self._weighted_summary() + cumweights = torch.cumsum(weights, dim=-1) - weights / 2 + cumweights /= torch.sum(weights, dim=-1, keepdim=True) + result = torch.zeros_like(data).float() + # numpy is needed for interpolation + ndata = data.cpu().numpy().reshape((data.shape[0], -1)) + ncw = cumweights.cpu().numpy() + nsm = summary.cpu().numpy() + for d in range(self.depth): + normed = torch.tensor(numpy.interp(ndata[d], nsm[d], ncw[d]), + dtype=torch.float, device=data.device).clamp_(0.0, 1.0) + if len(data.shape) > 1: + normed = normed.view(*(data.shape[1:])) + result[d] = normed + return result + + +class RunningConditionalQuantile: + ''' + Equivalent to a map from conditions (any python hashable type) + to RunningQuantiles. The reason for the type is to allow limited + GPU memory to be exploited while counting quantile stats on many + different conditions, a few of which are common and which benefit + from GPU, but most of which are rare and would not all fit into + GPU RAM. 
+ + To move a set of conditions to a device, use rcq.to_(device, conds). + Then in the future, move the tallied data to the device before + calling rcq.add, that is, rcq.add(cond, data.to(device)). + + To allow the caller to decide which conditions to allow to use GPU, + rcq.most_common_conditions(n) returns a list of the n most commonly + added conditions so far. + ''' + def __init__(self, resolution=6 * 1024, buffersize=None, seed=None, + state=None): + self.first_rq = None + self.call_stats = defaultdict(int) + self.running_quantiles = {} + if state is not None: + self.set_state_dict(state) + return + self.rq_args = dict(resolution=resolution, buffersize=buffersize, + seed=seed) + + def add(self, condition, incoming): + if condition not in self.running_quantiles: + self.running_quantiles[condition] = RunningQuantile(**self.rq_args) + if self.first_rq is None: + self.first_rq = self.running_quantiles[condition] + self.call_stats[condition] += 1 + rq = self.running_quantiles[condition] + # For performance reasons, the caller can move some conditions to + # the CPU if they are not among the most common conditions. + if rq.device is not None and (rq.device != incoming.device): + rq.to_(incoming.device) + self.running_quantiles[condition].add(incoming) + + def most_common_conditions(self, n): + return sorted(self.call_stats.keys(), + key=lambda c: -self.call_stats[c])[:n] + + def collected_add(self, conditions, incoming): + for c in conditions: + self.add(c, incoming) + + def conditional(self, c): + return self.running_quantiles[c] + + def collected_quantiles(self, conditions, quantiles, old_style=False): + result = torch.zeros( + size=(len(conditions), self.first_rq.depth, len(quantiles)), + dtype=self.first_rq.dtype, + device=self.first_rq.device) + for i, c in enumerate(conditions): + if c in self.running_quantiles: + result[i] = self.running_quantiles[c].quantiles( + quantiles, old_style) + return result + + def collected_normalize(self, conditions, values): + result = torch.zeros( + size=(len(conditions), values.shape[0], values.shape[1]), + dtype=torch.float, + device=self.first_rq.device) + for i, c in enumerate(conditions): + if c in self.running_quantiles: + result[i] = self.running_quantiles[c].normalize(values) + return result + + def to_(self, device, conditions=None): + if conditions is None: + conditions = self.running_quantiles.keys() + for cond in conditions: + if cond in self.running_quantiles: + self.running_quantiles[cond].to_(device) + + def state_dict(self): + conditions = sorted(self.running_quantiles.keys()) + result = dict( + constructor=self.__module__ + '.' + + self.__class__.__name__ + '()', + rq_args=self.rq_args, + conditions=conditions) + for i, c in enumerate(conditions): + result.update({ + '%d.%s' % (i, k): v + for k, v in self.running_quantiles[c].state_dict().items()}) + return result + + def set_state_dict(self, dic): + self.rq_args = dic['rq_args'].item() + conditions = list(dic['conditions']) + subdicts = defaultdict(dict) + for k, v in dic.items(): + if '.' 
in k: + p, s = k.split('.', 1) + subdicts[p][s] = v + self.running_quantiles = { + c: RunningQuantile(state=subdicts[str(i)]) + for i, c in enumerate(conditions)} + if conditions: + self.first_rq = self.running_quantiles[conditions[0]] + + # example usage: + # levels = rqc.conditional(()).quantiles(1 - fracs) + # denoms = 1 - rqc.collected_normalize(cats, levels) + # isects = 1 - rqc.collected_normalize(labels, levels) + # unions = fracs + denoms[cats] - isects + # iou = isects / unions + + + + +class RunningCrossCovariance: + ''' + Running computation. Use this when an off-diagonal block of the + covariance matrix is needed (e.g., when the whole covariance matrix + does not fit in the GPU). + + Chan-style numerically stable update of mean and full covariance matrix. + Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386 + ''' + def __init__(self, state=None): + if state is not None: + self.set_state_dict(state) + return + self.count = 0 + self._mean = None + self.cmom2 = None + self.v_cmom2 = None + + def add(self, a, b): + if len(a.shape) == 1: + a = a[None, :] + b = b[None, :] + assert(a.shape[0] == b.shape[0]) + if len(a.shape) > 2: + a, b = [d.view(d.shape[0], d.shape[1], -1).permute(0, 2, 1 + ).contiguous().view(-1, d.shape[1]) for d in [a, b]] + batch_count = a.shape[0] + batch_mean = [d.sum(0) / batch_count for d in [a, b]] + centered = [d - bm for d, bm in zip([a, b], batch_mean)] + # If more than 10 billion operations, divide into batches. + sub_batch = -(-(10 << 30) // (a.shape[1] * b.shape[1])) + # Initial batch. + if self._mean is None: + self.count = batch_count + self._mean = batch_mean + self.v_cmom2 = [c.pow(2).sum(0) for c in centered] + self.cmom2 = a.new(a.shape[1], b.shape[1]).zero_() + progress_addbmm(self.cmom2, centered[0][:,:,None], + centered[1][:,None,:], sub_batch) + return + # Update a batch using Chan-style update for numerical stability. + oldcount = self.count + self.count += batch_count + new_frac = float(batch_count) / self.count + # Update the mean according to the batch deviation from the old mean. + delta = [bm.sub_(m).mul_(new_frac) + for bm, m in zip(batch_mean, self._mean)] + for m, d in zip(self._mean, delta): + m.add_(d) + # Update the cross-covariance using the batch deviation + progress_addbmm(self.cmom2, centered[0][:,:,None], + centered[1][:,None,:], sub_batch) + self.cmom2.addmm_(alpha=new_frac * oldcount, + mat1=delta[0][:,None], mat2=delta[1][None,:]) + # Update the variance using the batch deviation + for c, vc2, d in zip(centered, self.v_cmom2, delta): + vc2.add_(c.pow(2).sum(0)) + vc2.add_(d.pow_(2).mul_(new_frac * oldcount)) + + def mean(self): + return self._mean + + def variance(self): + return [vc2 / (self.count - 1) for vc2 in self.v_cmom2] + + def stdev(self): + return [v.sqrt() for v in self.variance()] + + def covariance(self): + return self.cmom2 / (self.count - 1) + + def correlation(self): + covariance = self.covariance() + rstdev = [s.reciprocal() for s in self.stdev()] + cor = rstdev[0][:,None] * covariance * rstdev[1][None,:] + # Remove NaNs + cor[torch.isnan(cor)] = 0 + return cor + + def to_(self, device): + self._mean = [m.to(device) for m in self._mean] + self.v_cmom2 = [vcs.to(device) for vcs in self.v_cmom2] + self.cmom2 = self.cmom2.to(device) + + def state_dict(self): + return dict( + constructor=self.__module__ + '.' 
+ + self.__class__.__name__ + '()', + count=self.count, + mean_a=self._mean[0].cpu().numpy(), + mean_b=self._mean[1].cpu().numpy(), + cmom2_a=self.v_cmom2[0].cpu().numpy(), + cmom2_b=self.v_cmom2[1].cpu().numpy(), + cmom2=self.cmom2.cpu().numpy()) + + def set_state_dict(self, dic): + self.count = dic['count'].item() + self._mean = [torch.from_numpy(dic[k]) for k in ['mean_a', 'mean_b']] + self.v_cmom2 = [torch.from_numpy(dic[k]) + for k in ['cmom2_a', 'cmom2_b']] + self.cmom2 = torch.from_numpy(dic['cmom2']) + +def progress_addbmm(accum, x, y, batch_size): + ''' + Break up very large adbmm operations into batches so progress can be seen. + ''' + from .progress import default_progress + if x.shape[0] <= batch_size: + return accum.addbmm_(x, y) + progress = default_progress(None) + for i in progress(range(0, x.shape[0], batch_size), desc='bmm'): + accum.addbmm_(x[i:i+batch_size], y[i:i+batch_size]) + return accum + + +def sample_portion(vec, p=0.5): + bits = torch.bernoulli(torch.zeros(vec.shape[0], dtype=torch.uint8, + device=vec.device), p) + return vec[bits] + +if __name__ == '__main__': + import warnings + warnings.filterwarnings("error") + import time + import argparse + parser = argparse.ArgumentParser( + description='Test things out') + parser.add_argument('--mode', default='cpu', help='cpu or cuda') + parser.add_argument('--test_size', type=int, default=1000000) + args = parser.parse_args() + + # An adverarial case: we keep finding more numbers in the middle + # as the stream goes on. + amount = args.test_size + quantiles = 1000 + data = numpy.arange(float(amount)) + data[1::2] = data[-1::-2] + (len(data) - 1) + data /= 2 + depth = 50 + test_cuda = torch.cuda.is_available() + alldata = data[:,None] + (numpy.arange(depth) * amount)[None, :] + actual_sum = torch.FloatTensor(numpy.sum(alldata * alldata, axis=0)) + amt = amount // depth + for r in range(depth): + numpy.random.shuffle(alldata[r*amt:r*amt+amt,r]) + if args.mode == 'cuda': + alldata = torch.cuda.FloatTensor(alldata) + dtype = torch.float + device = torch.device('cuda') + else: + alldata = torch.FloatTensor(alldata) + dtype = torch.float + device = None + starttime = time.time() + qc = RunningQuantile(resolution=6 * 1024) + qc.add(alldata) + # Test state dict + saved = qc.state_dict() + # numpy.savez('foo.npz', **saved) + # saved = numpy.load('foo.npz') + qc = RunningQuantile(state=saved) + assert not qc.device.type == 'cuda' + qc.add(alldata) + actual_sum *= 2 + ro = qc.readout(1001).cpu() + endtime = time.time() + gt = torch.linspace(0, amount, quantiles+1)[None,:] + ( + torch.arange(qc.depth, dtype=torch.float) * amount)[:,None] + maxreldev = torch.max(torch.abs(ro - gt) / amount) * quantiles + print("Maximum relative deviation among %d perentiles: %f" % ( + quantiles, maxreldev)) + minerr = torch.max(torch.abs(qc.minmax().cpu()[:,0] - + torch.arange(qc.depth, dtype=torch.float) * amount)) + maxerr = torch.max(torch.abs((qc.minmax().cpu()[:, -1] + 1) - + (torch.arange(qc.depth, dtype=torch.float) + 1) * amount)) + print("Minmax error %f, %f" % (minerr, maxerr)) + interr = torch.max(torch.abs(qc.integrate(lambda x: x * x).cpu() + - actual_sum) / actual_sum) + print("Integral error: %f" % interr) + medianerr = torch.max(torch.abs(qc.median() - + alldata.median(0)[0]) / alldata.median(0)[0]).cpu() + print("Median error: %f" % interr) + meanerr = torch.max( + torch.abs(qc.mean() - alldata.mean(0)) / alldata.mean(0)).cpu() + print("Mean error: %f" % meanerr) + varerr = torch.max( + torch.abs(qc.variance() - alldata.var(0)) / 
alldata.var(0)).cpu() + print("Variance error: %f" % varerr) + counterr = ((qc.integrate(lambda x: torch.ones(x.shape[-1]).cpu()) + - qc.size) / (0.0 + qc.size)).item() + print("Count error: %f" % counterr) + print("Time %f" % (endtime - starttime)) + # Algorithm is randomized, so some of these will fail with low probability. + assert maxreldev < 1.0 + assert minerr == 0.0 + assert maxerr == 0.0 + assert interr < 0.01 + assert abs(counterr) < 0.001 + print("OK") diff --git a/netdissect/sampler.py b/netdissect/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..72f1b46da117403c7f6ddcc1877bd9d70ded962b --- /dev/null +++ b/netdissect/sampler.py @@ -0,0 +1,134 @@ +''' +A sampler is just a list of integer listing the indexes of the +inputs in a data set to sample. For reproducibility, the +FixedRandomSubsetSampler uses a seeded prng to produce the same +sequence always. FixedSubsetSampler is just a wrapper for an +explicit list of integers. + +coordinate_sample solves another sampling problem: when testing +convolutional outputs, we can reduce data explosing by sampling +random points of the feature map rather than the entire feature map. +coordinate_sample does this in a deterministic way that is also +resolution-independent. +''' + +import numpy +import random +from torch.utils.data.sampler import Sampler + +class FixedSubsetSampler(Sampler): + """Represents a fixed sequence of data set indices. + Subsets can be created by specifying a subset of output indexes. + """ + def __init__(self, samples): + self.samples = samples + + def __iter__(self): + return iter(self.samples) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, key): + return self.samples[key] + + def subset(self, new_subset): + return FixedSubsetSampler(self.dereference(new_subset)) + + def dereference(self, indices): + ''' + Translate output sample indices (small numbers indexing the sample) + to input sample indices (larger number indexing the original full set) + ''' + return [self.samples[i] for i in indices] + + +class FixedRandomSubsetSampler(FixedSubsetSampler): + """Samples a fixed number of samples from the dataset, deterministically. + Arguments: + data_source, + sample_size, + seed (optional) + """ + def __init__(self, data_source, start=None, end=None, seed=1): + rng = random.Random(seed) + shuffled = list(range(len(data_source))) + rng.shuffle(shuffled) + self.data_source = data_source + super(FixedRandomSubsetSampler, self).__init__(shuffled[start:end]) + + def class_subset(self, class_filter): + ''' + Returns only the subset matching the given rule. + ''' + if isinstance(class_filter, int): + rule = lambda d: d[1] == class_filter + else: + rule = class_filter + return self.subset([i for i, j in enumerate(self.samples) + if rule(self.data_source[j])]) + +def coordinate_sample(shape, sample_size, seeds, grid=13, seed=1, flat=False): + ''' + Returns a (end-start) sets of sample_size grid points within + the shape given. If the shape dimensions are a multiple of 'grid', + then sampled points within the same row will never be duplicated. + ''' + if flat: + sampind = numpy.zeros((len(seeds), sample_size), dtype=int) + else: + sampind = numpy.zeros((len(seeds), 2, sample_size), dtype=int) + assert sample_size <= grid + for j, seed in enumerate(seeds): + rng = numpy.random.RandomState(seed) + # Shuffle the 169 random grid squares, and pick :sample_size. 
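+            # (169 comes from the default grid=13, i.e. 13 * 13 squares for a
+            # 2-d shape.  Because we take a prefix of one full shuffle, a
+            # smaller sample_size selects a subset of a larger one for the
+            # same seed.)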
+ square_count = grid ** len(shape) + square = numpy.stack(numpy.unravel_index( + rng.choice(square_count, square_count)[:sample_size], + (grid,) * len(shape))) + # Then add a random offset to each x, y and put in the range [0...1) + # Notice this selects the same locations regardless of resolution. + uniform = (square + rng.uniform(size=square.shape)) / grid + # TODO: support affine scaling so that we can align receptive field + # centers exactly when sampling neurons in different layers. + coords = (uniform * numpy.array(shape)[:,None]).astype(int) + # Now take sample_size without replacement. We do this in a way + # such that if sample_size is decreased or increased up to 'grid', + # the selected points become a subset, not totally different points. + if flat: + sampind[j] = numpy.ravel_multi_index(coords, dims=shape) + else: + sampind[j] = coords + return sampind + +if __name__ == '__main__': + from numpy.testing import assert_almost_equal + # Test that coordinate_sample is deterministic, in-range, and scalable. + assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102)), + [[[14, 0, 12, 11, 8, 13, 11, 20, 7, 20], + [ 9, 22, 7, 11, 23, 18, 21, 15, 2, 5]]]) + assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 102)), + [[[ 7, 0, 6, 5, 4, 6, 5, 10, 3, 20 // 2], + [ 4, 11, 3, 5, 11, 9, 10, 7, 1, 5 // 2]]]) + assert_almost_equal(coordinate_sample((13, 13), 10, range(100, 102), + flat=True), + [[ 8, 24, 67, 103, 87, 79, 138, 94, 98, 53], + [ 95, 11, 81, 70, 63, 87, 75, 137, 40, 2+10*13]]) + assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 103), + flat=True), + [[ 95, 11, 81, 70, 63, 87, 75, 137, 40, 132], + [ 0, 78, 114, 111, 66, 45, 72, 73, 79, 135]]) + assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102), + flat=True), + [[373, 22, 319, 297, 231, 356, 307, 535, 184, 5+20*26]]) + # Test FixedRandomSubsetSampler + fss = FixedRandomSubsetSampler(range(10)) + assert len(fss) == 10 + assert_almost_equal(list(fss), [8, 0, 3, 4, 5, 2, 9, 6, 7, 1]) + fss = FixedRandomSubsetSampler(range(10), 3, 8) + assert len(fss) == 5 + assert_almost_equal(list(fss), [4, 5, 2, 9, 6]) + fss = FixedRandomSubsetSampler([(i, i % 3) for i in range(10)], + class_filter=1) + assert len(fss) == 3 + assert_almost_equal(list(fss), [4, 7, 1]) diff --git a/netdissect/segdata.py b/netdissect/segdata.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cb6dfac8985d9c55344abbc26cc26c4862aa85 --- /dev/null +++ b/netdissect/segdata.py @@ -0,0 +1,74 @@ +import os, numpy, torch, json +from .parallelfolder import ParallelImageFolders +from torchvision import transforms +from torchvision.transforms.functional import to_tensor, normalize + +class FieldDef(object): + def __init__(self, field, index, bitshift, bitmask, labels): + self.field = field + self.index = index + self.bitshift = bitshift + self.bitmask = bitmask + self.labels = labels + +class MultiSegmentDataset(object): + ''' + Just like ClevrMulticlassDataset, but the second stream is a one-hot + segmentation tensor rather than a flat one-hot presence vector. 
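+    Each item is a tuple (image, segmentation, bincount): the segmentation is
+    a (num_categories, H, W) tensor of long labels (label 0 means "no label"),
+    and bincount tallies the pixels of each label across all category channels.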
+ + MultiSegmentDataset('dataset/clevrseg', + imgdir='images/train/positive', + segdir='images/train/segmentation') + ''' + def __init__(self, directory, transform=None, + imgdir='img', segdir='seg', val=False, size=None): + self.segdataset = ParallelImageFolders( + [os.path.join(directory, imgdir), + os.path.join(directory, segdir)], + transform=transform) + self.fields = [] + with open(os.path.join(directory, 'labelnames.json'), 'r') as f: + for defn in json.load(f): + self.fields.append(FieldDef( + defn['field'], defn['index'], defn['bitshift'], + defn['bitmask'], defn['label'])) + self.labels = ['-'] # Reserve label 0 to mean "no label" + self.categories = [] + self.label_category = [0] + for fieldnum, f in enumerate(self.fields): + self.categories.append(f.field) + f.firstchannel = len(self.labels) + f.channels = len(f.labels) - 1 + for lab in f.labels[1:]: + self.labels.append(lab) + self.label_category.append(fieldnum) + # Reserve 25% of the dataset for validation. + first_val = int(len(self.segdataset) * 0.75) + self.val = val + self.first = first_val if val else 0 + self.length = len(self.segdataset) - first_val if val else first_val + # Truncate the dataset if requested. + if size: + self.length = min(size, self.length) + + def __len__(self): + return self.length + + def __getitem__(self, index): + img, segimg = self.segdataset[index + self.first] + segin = numpy.array(segimg, numpy.uint8, copy=False) + segout = torch.zeros(len(self.categories), + segin.shape[0], segin.shape[1], dtype=torch.int64) + for i, field in enumerate(self.fields): + fielddata = ((torch.from_numpy(segin[:, :, field.index]) + >> field.bitshift) & field.bitmask) + segout[i] = field.firstchannel + fielddata - 1 + bincount = numpy.bincount(segout.flatten(), + minlength=len(self.labels)) + return img, segout, bincount + +if __name__ == '__main__': + ds = MultiSegmentDataset('dataset/clevrseg') + print(ds[0]) + import pdb; pdb.set_trace() + diff --git a/netdissect/segmenter.py b/netdissect/segmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ebe364bc30f32581f0d560e11f08bfbd0d1731 --- /dev/null +++ b/netdissect/segmenter.py @@ -0,0 +1,581 @@ +# Usage as a simple differentiable segmenter base class + +import os, torch, numpy, json, glob +import skimage.morphology +from collections import OrderedDict +from netdissect import upsegmodel +from netdissect import segmodel as segmodel_module +from netdissect.easydict import EasyDict +from urllib.request import urlretrieve + +class BaseSegmenter: + def get_label_and_category_names(self): + ''' + Returns two lists: first, a list of tuples [(label, category), ...] + where the label and category are human-readable strings indicating + the meaning of a segmentation class. The 0th segmentation class + should be reserved for a label ('-') that means "no prediction." + The second list should just be a list of [category,...] listing + all categories in a canonical order. + ''' + raise NotImplemented() + + def segment_batch(self, tensor_images, downsample=1): + ''' + Returns a multilabel segmentation for the given batch of (RGB [-1...1]) + images. Each pixel of the result is a torch.long indicating a + predicted class number. Multiple classes can be predicted for + the same pixel: output shape is (n, multipred, y, x), where + multipred is 3, 5, or 6, for how many different predicted labels can + be given for each pixel (depending on whether subdivision is being + used). 
If downsample is specified, then the output y and x dimensions + are downsampled from the original image. + ''' + raise NotImplemented() + + def predict_single_class(self, tensor_images, classnum, downsample=1): + ''' + Given a batch of images (RGB, normalized to [-1...1]) and + a specific segmentation class number, returns a tuple with + (1) a differentiable ([0..1]) prediction score for the class + at every pixel of the input image. + (2) a binary mask showing where in the input image the + specified class is the best-predicted label for the pixel. + Does not work on subdivided labels. + ''' + raise NotImplemented() + +class UnifiedParsingSegmenter(BaseSegmenter): + ''' + This is a wrapper for a more complicated multi-class segmenter, + as described in https://arxiv.org/pdf/1807.10221.pdf, and as + released in https://github.com/CSAILVision/unifiedparsing. + For our purposes and to simplify processing, we do not use + whole-scene predictions, and we only consume part segmentations + for the three largest object classes (sky, building, person). + ''' + + def __init__(self, segsizes=None, segdiv=None): + # Create a segmentation model + if segsizes is None: + segsizes = [256] + if segdiv == None: + segdiv = 'undivided' + segvocab = 'upp' + segarch = ('resnet50', 'upernet') + epoch = 40 + segmodel = load_unified_parsing_segmentation_model( + segarch, segvocab, epoch) + segmodel.cuda() + self.segmodel = segmodel + self.segsizes = segsizes + self.segdiv = segdiv + mult = 1 + if self.segdiv == 'quad': + mult = 5 + self.divmult = mult + # Assign class numbers for parts. + first_partnumber = ( + (len(segmodel.labeldata['object']) - 1) * mult + 1 + + (len(segmodel.labeldata['material']) - 1)) + # We only use parts for these three types of objects, for efficiency. + partobjects = ['sky', 'building', 'person'] + partnumbers = {} + partnames = [] + objectnumbers = {k: v + for v, k in enumerate(segmodel.labeldata['object'])} + part_index_translation = [] + # We merge some classes. For example "door" is both an object + # and a part of a building. To avoid confusion, we just count + # such classes as objects, and add part scores to the same index. + for owner in partobjects: + part_list = segmodel.labeldata['object_part'][owner] + numeric_part_list = [] + for part in part_list: + if part in objectnumbers: + numeric_part_list.append(objectnumbers[part]) + elif part in partnumbers: + numeric_part_list.append(partnumbers[part]) + else: + partnumbers[part] = len(partnames) + first_partnumber + partnames.append(part) + numeric_part_list.append(partnumbers[part]) + part_index_translation.append(torch.tensor(numeric_part_list)) + self.objects_with_parts = [objectnumbers[obj] for obj in partobjects] + self.part_index = part_index_translation + self.part_names = partnames + # For now we'll just do object and material labels. + self.num_classes = 1 + ( + len(segmodel.labeldata['object']) - 1) * mult + ( + len(segmodel.labeldata['material']) - 1) + len(partnames) + self.num_object_classes = len(self.segmodel.labeldata['object']) - 1 + + def get_label_and_category_names(self, dataset=None): + ''' + Lists label and category names. + ''' + # Labels are ordered as follows: + # 0, [object labels] [divided object labels] [materials] [parts] + # The zero label is reserved to mean 'no prediction'. 
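+        # For example, with segdiv='quad' the divided copies (each object
+        # label with a '-t', '-l', '-b' or '-r' suffix) come immediately after
+        # the undivided object labels, followed by the material labels and
+        # then the part labels.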
+ if self.segdiv == 'quad': + suffixes = ['t', 'l', 'b', 'r'] + else: + suffixes = [] + divided_labels = [] + for suffix in suffixes: + divided_labels.extend([('%s-%s' % (label, suffix), 'part') + for label in self.segmodel.labeldata['object'][1:]]) + # Create the whole list of labels + labelcats = ( + [(label, 'object') + for label in self.segmodel.labeldata['object']] + + divided_labels + + [(label, 'material') + for label in self.segmodel.labeldata['material'][1:]] + + [(label, 'part') for label in self.part_names]) + return labelcats, ['object', 'part', 'material'] + + def raw_seg_prediction(self, tensor_images, downsample=1): + ''' + Generates a segmentation by applying multiresolution voting on + the segmentation model, using (rounded to 32 pixels) a set of + resolutions in the example benchmark code. + ''' + y, x = tensor_images.shape[2:] + b = len(tensor_images) + tensor_images = (tensor_images + 1) / 2 * 255 + tensor_images = torch.flip(tensor_images, (1,)) # BGR!!!? + tensor_images -= torch.tensor([102.9801, 115.9465, 122.7717]).to( + dtype=tensor_images.dtype, device=tensor_images.device + )[None,:,None,None] + seg_shape = (y // downsample, x // downsample) + # We want these to be multiples of 32 for the model. + sizes = [(s, s) for s in self.segsizes] + pred = {category: torch.zeros( + len(tensor_images), len(self.segmodel.labeldata[category]), + seg_shape[0], seg_shape[1]).cuda() + for category in ['object', 'material']} + part_pred = {partobj_index: torch.zeros( + len(tensor_images), len(partindex), + seg_shape[0], seg_shape[1]).cuda() + for partobj_index, partindex in enumerate(self.part_index)} + for size in sizes: + if size == tensor_images.shape[2:]: + resized = tensor_images + else: + resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images) + r_pred = self.segmodel( + dict(img=resized), seg_size=seg_shape) + for k in pred: + pred[k] += r_pred[k] + for k in part_pred: + part_pred[k] += r_pred['part'][k] + return pred, part_pred + + def segment_batch(self, tensor_images, downsample=1): + ''' + Returns a multilabel segmentation for the given batch of (RGB [-1...1]) + images. Each pixel of the result is a torch.long indicating a + predicted class number. Multiple classes can be predicted for + the same pixel: output shape is (n, multipred, y, x), where + multipred is 3, 5, or 6, for how many different predicted labels can + be given for each pixel (depending on whether subdivision is being + used). If downsample is specified, then the output y and x dimensions + are downsampled from the original image. 
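+
+        A rough usage sketch (the tensor and variable names here are
+        illustrative assumptions, not part of this module):
+
+            segmenter = UnifiedParsingSegmenter(segsizes=[256])
+            segs = segmenter.segment_batch(images, downsample=4)
+            # segs[:, 0] holds object labels, segs[:, 1] material labels,
+            # and segs[:, 2] part labels, in the shared numbering reported
+            # by get_label_and_category_names().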
+ ''' + pred, part_pred = self.raw_seg_prediction(tensor_images, + downsample=downsample) + piece_channels = 2 if self.segdiv == 'quad' else 0 + y, x = tensor_images.shape[2:] + seg_shape = (y // downsample, x // downsample) + segs = torch.zeros(len(tensor_images), 3 + piece_channels, + seg_shape[0], seg_shape[1], + dtype=torch.long, device=tensor_images.device) + _, segs[:,0] = torch.max(pred['object'], dim=1) + # Get materials and translate to shared numbering scheme + _, segs[:,1] = torch.max(pred['material'], dim=1) + maskout = (segs[:,1] == 0) + segs[:,1] += (len(self.segmodel.labeldata['object']) - 1) * self.divmult + segs[:,1][maskout] = 0 + # Now deal with subparts of sky, buildings, people + for i, object_index in enumerate(self.objects_with_parts): + trans = self.part_index[i].to(segs.device) + # Get the argmax, and then translate to shared numbering scheme + seg = trans[torch.max(part_pred[i], dim=1)[1]] + # Only trust the parts where the prediction also predicts the + # owning object. + mask = (segs[:,0] == object_index) + segs[:,2][mask] = seg[mask] + + if self.segdiv == 'quad': + segs = self.expand_segment_quad(segs, self.segdiv) + return segs + + def predict_single_class(self, tensor_images, classnum, downsample=1): + ''' + Given a batch of images (RGB, normalized to [-1...1]) and + a specific segmentation class number, returns a tuple with + (1) a differentiable ([0..1]) prediction score for the class + at every pixel of the input image. + (2) a binary mask showing where in the input image the + specified class is the best-predicted label for the pixel. + Does not work on subdivided labels. + ''' + result = 0 + pred, part_pred = self.raw_seg_prediction(tensor_images, + downsample=downsample) + material_offset = (len(self.segmodel.labeldata['object']) - 1 + ) * self.divmult + if material_offset < classnum < material_offset + len( + self.segmodel.labeldata['material']): + return ( + pred['material'][:, classnum - material_offset], + pred['material'].max(dim=1)[1] == classnum - material_offset) + mask = None + if classnum < len(self.segmodel.labeldata['object']): + result = pred['object'][:, classnum] + mask = (pred['object'].max(dim=1)[1] == classnum) + # Some objects, like 'door', are also a part of other objects, + # so add the part prediction also. + for i, object_index in enumerate(self.objects_with_parts): + local_index = (self.part_index[i] == classnum).nonzero() + if len(local_index) == 0: + continue + local_index = local_index.item() + # Ignore part predictions outside the mask. (We could pay + # atttention to and penalize such predictions.) 
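+                # mask2 marks pixels where the owning object is the argmax
+                # object prediction *and* this part is the argmax part
+                # prediction within that object's part head.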
+ mask2 = (pred['object'].max(dim=1)[1] == object_index) * ( + part_pred[i].max(dim=1)[1] == local_index) + if mask is None: + mask = mask2 + else: + mask = torch.max(mask, mask2) + result = result + (part_pred[i][:, local_index]) + assert result is not 0, 'unrecognized class %d' % classnum + return result, mask + + def expand_segment_quad(self, segs, segdiv='quad'): + shape = segs.shape + segs[:,3:] = segs[:,0:1] # start by copying the object channel + num_seg_labels = self.num_object_classes + # For every connected component present (using generator) + for i, mask in component_masks(segs[:,0:1]): + # Figure the bounding box of the label + top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0] + left, right = mask.any(dim=0).nonzero()[[0, -1], 0] + # Chop the bounding box into four parts + vmid = (top + bottom + 1) // 2 + hmid = (left + right + 1) // 2 + # Construct top, bottom, right, left masks + quad_mask = mask[None,:,:].repeat(4, 1, 1) + quad_mask[0, vmid:, :] = 0 # top + quad_mask[1, :, hmid:] = 0 # right + quad_mask[2, :vmid, :] = 0 # bottom + quad_mask[3, :, :hmid] = 0 # left + quad_mask = quad_mask.long() + # Modify extra segmentation labels by offsetting + segs[i,3,:,:] += quad_mask[0] * num_seg_labels + segs[i,4,:,:] += quad_mask[1] * (2 * num_seg_labels) + segs[i,3,:,:] += quad_mask[2] * (3 * num_seg_labels) + segs[i,4,:,:] += quad_mask[3] * (4 * num_seg_labels) + # remove any components that were too small to subdivide + mask = segs[:,3:] <= self.num_object_classes + segs[:,3:][mask] = 0 + return segs + +class SemanticSegmenter(BaseSegmenter): + def __init__(self, modeldir=None, segarch=None, segvocab=None, + segsizes=None, segdiv=None, epoch=None): + # Create a segmentation model + if modeldir == None: + modeldir = 'dataset/segmodel' + if segvocab == None: + segvocab = 'baseline' + if segarch == None: + segarch = ('resnet50_dilated8', 'ppm_bilinear_deepsup') + if segdiv == None: + segdiv = 'undivided' + elif isinstance(segarch, str): + segarch = segarch.split(',') + segmodel = load_segmentation_model(modeldir, segarch, segvocab, epoch) + if segsizes is None: + segsizes = getattr(segmodel.meta, 'segsizes', [256]) + self.segsizes = segsizes + # Verify segmentation model to has every out_channel labeled. + assert len(segmodel.meta.labels) == list(c for c in segmodel.modules() + if isinstance(c, torch.nn.Conv2d))[-1].out_channels + segmodel.cuda() + self.segmodel = segmodel + self.segdiv = segdiv + # Image normalization + self.bgr = (segmodel.meta.imageformat.byteorder == 'BGR') + self.imagemean = torch.tensor(segmodel.meta.imageformat.mean) + self.imagestd = torch.tensor(segmodel.meta.imageformat.stdev) + # Map from labels to external indexes, and labels to channel sets. + self.labelmap = {'-': 0} + self.channelmap = {'-': []} + self.labels = [('-', '-')] + num_labels = 1 + self.num_underlying_classes = len(segmodel.meta.labels) + # labelmap maps names to external indexes. + for i, label in enumerate(segmodel.meta.labels): + if label.name not in self.channelmap: + self.channelmap[label.name] = [] + self.channelmap[label.name].append(i) + if getattr(label, 'internal', None) or label.name in self.labelmap: + continue + self.labelmap[label.name] = num_labels + num_labels += 1 + self.labels.append((label.name, label.category)) + # Each category gets its own independent softmax. 
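+        # category_indexes maps each category name to the raw output channels
+        # that belong to it, so raw_seg_prediction can apply an independent
+        # softmax to each group.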
+ self.category_indexes = { category.name: + [i for i, label in enumerate(segmodel.meta.labels) + if label.category == category.name] + for category in segmodel.meta.categories } + # catindexmap maps names to category internal indexes + self.catindexmap = {} + for catname, indexlist in self.category_indexes.items(): + for index, i in enumerate(indexlist): + self.catindexmap[segmodel.meta.labels[i].name] = ( + (catname, index)) + # After the softmax, each category is mapped to external indexes. + self.category_map = { catname: + torch.tensor([ + self.labelmap.get(segmodel.meta.labels[ind].name, 0) + for ind in catindex]) + for catname, catindex in self.category_indexes.items()} + self.category_rules = segmodel.meta.categories + # Finally, naive subdivision can be applied. + mult = 1 + if self.segdiv == 'quad': + mult = 5 + suffixes = ['t', 'l', 'b', 'r'] + divided_labels = [] + for suffix in suffixes: + divided_labels.extend([('%s-%s' % (label, suffix), cat) + for label, cat in self.labels[1:]]) + self.channelmap.update({ + '%s-%s' % (label, suffix): self.channelmap[label] + for label, cat in self.labels[1:] }) + self.labels.extend(divided_labels) + # For examining a single class + self.channellist = [self.channelmap[name] for name, _ in self.labels] + + def get_label_and_category_names(self, dataset=None): + return self.labels, self.segmodel.categories + + def segment_batch(self, tensor_images, downsample=1): + return self.raw_segment_batch(tensor_images, downsample)[0] + + def raw_segment_batch(self, tensor_images, downsample=1): + pred = self.raw_seg_prediction(tensor_images, downsample) + catsegs = {} + for catkey, catindex in self.category_indexes.items(): + _, segs = torch.max(pred[:, catindex], dim=1) + catsegs[catkey] = segs + masks = {} + segs = torch.zeros(len(tensor_images), len(self.category_rules), + pred.shape[2], pred.shape[2], device=pred.device, + dtype=torch.long) + for i, cat in enumerate(self.category_rules): + catmap = self.category_map[cat.name].to(pred.device) + translated = catmap[catsegs[cat.name]] + if getattr(cat, 'mask', None) is not None: + if cat.mask not in masks: + maskcat, maskind = self.catindexmap[cat.mask] + masks[cat.mask] = (catsegs[maskcat] == maskind) + translated *= masks[cat.mask].long() + segs[:,i] = translated + if self.segdiv == 'quad': + segs = self.expand_segment_quad(segs, + self.num_underlying_classes, self.segdiv) + return segs, pred + + def raw_seg_prediction(self, tensor_images, downsample=1): + ''' + Generates a segmentation by applying multiresolution voting on + the segmentation model, using (rounded to 32 pixels) a set of + resolutions in the example benchmark code. + ''' + y, x = tensor_images.shape[2:] + b = len(tensor_images) + # Flip the RGB order if specified. + if self.bgr: + tensor_images = torch.flip(tensor_images, (1,)) + # Transform from our [-1..1] range to torch standard [0..1] range + # and then apply normalization. + tensor_images = ((tensor_images + 1) / 2 + ).sub_(self.imagemean[None,:,None,None].to(tensor_images.device) + ).div_(self.imagestd[None,:,None,None].to(tensor_images.device)) + # Output shape can be downsampled. + seg_shape = (y // downsample, x // downsample) + # We want these to be multiples of 32 for the model. 
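+        # The per-category softmax predictions from each resolution are summed
+        # below, which is what implements the multiresolution voting mentioned
+        # in the docstring.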
+ sizes = [(s, s) for s in self.segsizes] + pred = torch.zeros( + len(tensor_images), (self.num_underlying_classes), + seg_shape[0], seg_shape[1]).cuda() + for size in sizes: + if size == tensor_images.shape[2:]: + resized = tensor_images + else: + resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images) + raw_pred = self.segmodel( + dict(img_data=resized), segSize=seg_shape) + softmax_pred = torch.empty_like(raw_pred) + for catindex in self.category_indexes.values(): + softmax_pred[:, catindex] = torch.nn.functional.softmax( + raw_pred[:, catindex], dim=1) + pred += softmax_pred + return pred + + def expand_segment_quad(self, segs, num_seg_labels, segdiv='quad'): + shape = segs.shape + output = segs.repeat(1, 3, 1, 1) + # For every connected component present (using generator) + for i, mask in component_masks(segs): + # Figure the bounding box of the label + top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0] + left, right = mask.any(dim=0).nonzero()[[0, -1], 0] + # Chop the bounding box into four parts + vmid = (top + bottom + 1) // 2 + hmid = (left + right + 1) // 2 + # Construct top, bottom, right, left masks + quad_mask = mask[None,:,:].repeat(4, 1, 1) + quad_mask[0, vmid:, :] = 0 # top + quad_mask[1, :, hmid:] = 0 # right + quad_mask[2, :vmid, :] = 0 # bottom + quad_mask[3, :, :hmid] = 0 # left + quad_mask = quad_mask.long() + # Modify extra segmentation labels by offsetting + output[i,1,:,:] += quad_mask[0] * num_seg_labels + output[i,2,:,:] += quad_mask[1] * (2 * num_seg_labels) + output[i,1,:,:] += quad_mask[2] * (3 * num_seg_labels) + output[i,2,:,:] += quad_mask[3] * (4 * num_seg_labels) + return output + + def predict_single_class(self, tensor_images, classnum, downsample=1): + ''' + Given a batch of images (RGB, normalized to [-1...1]) and + a specific segmentation class number, returns a tuple with + (1) a differentiable ([0..1]) prediction score for the class + at every pixel of the input image. + (2) a binary mask showing where in the input image the + specified class is the best-predicted label for the pixel. + Does not work on subdivided labels. + ''' + seg, pred = self.raw_segment_batch(tensor_images, + downsample=downsample) + result = pred[:,self.channellist[classnum]].sum(dim=1) + mask = (seg == classnum).max(1)[0] + return result, mask + +def component_masks(segmentation_batch): + ''' + Splits connected components into regions (slower, requires cpu). 
+ ''' + npbatch = segmentation_batch.cpu().numpy() + for i in range(segmentation_batch.shape[0]): + labeled, num = skimage.morphology.label(npbatch[i][0], return_num=True) + labeled = torch.from_numpy(labeled).to(segmentation_batch.device) + for label in range(1, num): + yield i, (labeled == label) + +def load_unified_parsing_segmentation_model(segmodel_arch, segvocab, epoch): + segmodel_dir = 'dataset/segmodel/%s-%s-%s' % ((segvocab,) + segmodel_arch) + # Load json of class names and part/object structure + with open(os.path.join(segmodel_dir, 'labels.json')) as f: + labeldata = json.load(f) + nr_classes={k: len(labeldata[k]) + for k in ['object', 'scene', 'material']} + nr_classes['part'] = sum(len(p) for p in labeldata['object_part'].values()) + # Create a segmentation model + segbuilder = upsegmodel.ModelBuilder() + # example segmodel_arch = ('resnet101', 'upernet') + seg_encoder = segbuilder.build_encoder( + arch=segmodel_arch[0], + fc_dim=2048, + weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch)) + seg_decoder = segbuilder.build_decoder( + arch=segmodel_arch[1], + fc_dim=2048, use_softmax=True, + nr_classes=nr_classes, + weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch)) + segmodel = upsegmodel.SegmentationModule( + seg_encoder, seg_decoder, labeldata) + segmodel.categories = ['object', 'part', 'material'] + segmodel.eval() + return segmodel + +def load_segmentation_model(modeldir, segmodel_arch, segvocab, epoch=None): + # Load csv of class names + segmodel_dir = 'dataset/segmodel/%s-%s-%s' % ((segvocab,) + segmodel_arch) + with open(os.path.join(segmodel_dir, 'labels.json')) as f: + labeldata = EasyDict(json.load(f)) + # Automatically pick the last epoch available. + if epoch is None: + choices = [os.path.basename(n)[14:-4] for n in + glob.glob(os.path.join(segmodel_dir, 'encoder_epoch_*.pth'))] + epoch = max([int(c) for c in choices if c.isdigit()]) + # Create a segmentation model + segbuilder = segmodel_module.ModelBuilder() + # example segmodel_arch = ('resnet101', 'upernet') + seg_encoder = segbuilder.build_encoder( + arch=segmodel_arch[0], + fc_dim=2048, + weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch)) + seg_decoder = segbuilder.build_decoder( + arch=segmodel_arch[1], + fc_dim=2048, inference=True, num_class=len(labeldata.labels), + weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch)) + segmodel = segmodel_module.SegmentationModule(seg_encoder, seg_decoder, + torch.nn.NLLLoss(ignore_index=-1)) + segmodel.categories = [cat.name for cat in labeldata.categories] + segmodel.labels = [label.name for label in labeldata.labels] + categories = OrderedDict() + label_category = numpy.zeros(len(segmodel.labels), dtype=int) + for i, label in enumerate(labeldata.labels): + label_category[i] = segmodel.categories.index(label.category) + segmodel.meta = labeldata + segmodel.eval() + return segmodel + +def ensure_upp_segmenter_downloaded(directory): + baseurl = 'http://netdissect.csail.mit.edu/data/segmodel' + dirname = 'upp-resnet50-upernet' + files = ['decoder_epoch_40.pth', 'encoder_epoch_40.pth', 'labels.json'] + download_dir = os.path.join(directory, dirname) + os.makedirs(download_dir, exist_ok=True) + for fn in files: + if os.path.isfile(os.path.join(download_dir, fn)): + continue # Skip files already downloaded + url = '%s/%s/%s' % (baseurl, dirname, fn) + print('Downloading %s' % url) + urlretrieve(url, os.path.join(download_dir, fn)) + assert os.path.isfile(os.path.join(directory, dirname, 'labels.json')) + +def 
test_main(): + ''' + Test the unified segmenter. + ''' + from PIL import Image + testim = Image.open('script/testdata/test_church_242.jpg') + tensor_im = (torch.from_numpy(numpy.asarray(testim)).permute(2, 0, 1) + .float() / 255 * 2 - 1)[None, :, :, :].cuda() + segmenter = UnifiedParsingSegmenter() + seg = segmenter.segment_batch(tensor_im) + bc = torch.bincount(seg.view(-1)) + labels, cats = segmenter.get_label_and_category_names() + for label in bc.nonzero()[:,0]: + if label.item(): + # What is the prediction for this class? + pred, mask = segmenter.predict_single_class(tensor_im, label.item()) + assert mask.sum().item() == bc[label].item() + assert len(((seg == label).max(1)[0] - mask).nonzero()) == 0 + inside_pred = pred[mask].mean().item() + outside_pred = pred[~mask].mean().item() + print('%s (%s, #%d): %d pixels, pred %.2g inside %.2g outside' % + (labels[label.item()] + (label.item(), bc[label].item(), + inside_pred, outside_pred))) + +if __name__ == '__main__': + test_main() diff --git a/netdissect/segmodel/__init__.py b/netdissect/segmodel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76b40a0a36bc2976f185dbdc344c5a7c09b65920 --- /dev/null +++ b/netdissect/segmodel/__init__.py @@ -0,0 +1 @@ +from .models import ModelBuilder, SegmentationModule diff --git a/netdissect/segmodel/colors150.npy b/netdissect/segmodel/colors150.npy new file mode 100644 index 0000000000000000000000000000000000000000..2384b386dabded09c47329a360987a0d7f67d697 --- /dev/null +++ b/netdissect/segmodel/colors150.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9823be654b7a9c135e355952b580f567d5127e98d99ec792ebc349d5fc8c137 +size 578 diff --git a/netdissect/segmodel/models.py b/netdissect/segmodel/models.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb6f2ce21720722d5d8c9ee4f7e015ad06a9647 --- /dev/null +++ b/netdissect/segmodel/models.py @@ -0,0 +1,558 @@ +import torch +import torch.nn as nn +import torchvision +from . 
import resnet, resnext +try: + from lib.nn import SynchronizedBatchNorm2d +except ImportError: + from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d + + +class SegmentationModuleBase(nn.Module): + def __init__(self): + super(SegmentationModuleBase, self).__init__() + + def pixel_acc(self, pred, label): + _, preds = torch.max(pred, dim=1) + valid = (label >= 0).long() + acc_sum = torch.sum(valid * (preds == label).long()) + pixel_sum = torch.sum(valid) + acc = acc_sum.float() / (pixel_sum.float() + 1e-10) + return acc + + +class SegmentationModule(SegmentationModuleBase): + def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): + super(SegmentationModule, self).__init__() + self.encoder = net_enc + self.decoder = net_dec + self.crit = crit + self.deep_sup_scale = deep_sup_scale + + def forward(self, feed_dict, *, segSize=None): + if segSize is None: # training + if self.deep_sup_scale is not None: # use deep supervision technique + (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + else: + pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + + loss = self.crit(pred, feed_dict['seg_label']) + if self.deep_sup_scale is not None: + loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) + loss = loss + loss_deepsup * self.deep_sup_scale + + acc = self.pixel_acc(pred, feed_dict['seg_label']) + return loss, acc + else: # inference + pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) + return pred + + +def conv3x3(in_planes, out_planes, stride=1, has_bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=has_bias) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + conv3x3(in_planes, out_planes, stride), + SynchronizedBatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class ModelBuilder(): + # custom weights initialization + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
+ m.bias.data.fill_(1e-4) + #elif classname.find('Linear') != -1: + # m.weight.data.normal_(0.0, 0.0001) + + def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''): + pretrained = True if len(weights) == 0 else False + if arch == 'resnet34': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet34_dilated8': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=8) + elif arch == 'resnet34_dilated16': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=16) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet50_dilated8': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=8) + elif arch == 'resnet50_dilated16': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=16) + elif arch == 'resnet101': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet101_dilated8': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=8) + elif arch == 'resnet101_dilated16': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=16) + elif arch == 'resnext101': + orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnext) # we can still use class Resnet + else: + raise Exception('Architecture undefined!') + + # net_encoder.apply(self.weights_init) + if len(weights) > 0: + # print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_encoder + + def build_decoder(self, arch='ppm_bilinear_deepsup', + fc_dim=512, num_class=150, + weights='', inference=False, use_softmax=False): + if arch == 'c1_bilinear_deepsup': + net_decoder = C1BilinearDeepSup( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax) + elif arch == 'c1_bilinear': + net_decoder = C1Bilinear( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax) + elif arch == 'ppm_bilinear': + net_decoder = PPMBilinear( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax) + elif arch == 'ppm_bilinear_deepsup': + net_decoder = PPMBilinearDeepsup( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax) + elif arch == 'upernet_lite': + net_decoder = UPerNet( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax, + fpn_dim=256) + elif arch == 'upernet': + net_decoder = UPerNet( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax, + fpn_dim=512) + elif arch == 'upernet_tmp': + net_decoder = UPerNetTmp( + num_class=num_class, + fc_dim=fc_dim, + inference=inference, + use_softmax=use_softmax, + fpn_dim=512) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(self.weights_init) + if len(weights) > 0: + # print('Loading weights 
for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_decoder + + +class Resnet(nn.Module): + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +class ResnetDilated(nn.Module): + def __init__(self, orig_resnet, dilate_scale=8): + super(ResnetDilated, self).__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply( + partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate//2, dilate//2) + m.padding = (dilate//2, dilate//2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +# last conv, bilinear upsample +class C1BilinearDeepSup(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False): + super(C1BilinearDeepSup, self).__init__() + self.use_softmax = use_softmax + self.inference = inference + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + 
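+    # Note: at inference time (when self.inference or self.use_softmax is set)
+    # forward() returns only the main prediction; the deep-supervision branch
+    # is used only to produce the auxiliary output for training.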
+ def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.inference or self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + if self.use_softmax: + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv, bilinear upsample +class C1Bilinear(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False): + super(C1Bilinear, self).__init__() + self.use_softmax = use_softmax + self.inference = inference + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.inference or self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + if self.use_softmax: + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x + + +# pyramid pooling, bilinear upsample +class PPMBilinear(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPMBilinear, self).__init__() + self.use_softmax = use_softmax + self.inference = inference + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + SynchronizedBatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + SynchronizedBatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.inference or self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + if self.use_softmax: + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + return x + + +# pyramid pooling, bilinear upsample +class PPMBilinearDeepsup(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPMBilinearDeepsup, self).__init__() + self.use_softmax = use_softmax + self.inference = inference + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + SynchronizedBatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + SynchronizedBatchNorm2d(512), + 
nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.inference or self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + if self.use_softmax: + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# upernet +class UPerNet(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6), + fpn_inplanes=(256,512,1024,2048), fpn_dim=256): + super(UPerNet, self).__init__() + self.use_softmax = use_softmax + self.inference = inference + + # PPM Module + self.ppm_pooling = [] + self.ppm_conv = [] + + for scale in pool_scales: + self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) + self.ppm_conv.append(nn.Sequential( + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + SynchronizedBatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm_pooling = nn.ModuleList(self.ppm_pooling) + self.ppm_conv = nn.ModuleList(self.ppm_conv) + self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) + + # FPN Module + self.fpn_in = [] + for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer + self.fpn_in.append(nn.Sequential( + nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), + SynchronizedBatchNorm2d(fpn_dim), + nn.ReLU(inplace=True) + )) + self.fpn_in = nn.ModuleList(self.fpn_in) + + self.fpn_out = [] + for i in range(len(fpn_inplanes) - 1): # skip the top layer + self.fpn_out.append(nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + )) + self.fpn_out = nn.ModuleList(self.fpn_out) + + self.conv_last = nn.Sequential( + conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): + ppm_out.append(pool_conv(nn.functional.interploate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False))) + ppm_out = torch.cat(ppm_out, 1) + f = self.ppm_last_conv(ppm_out) + + fpn_feature_list = [f] + for i in reversed(range(len(conv_out) - 1)): + conv_x = conv_out[i] + conv_x = self.fpn_in[i](conv_x) # lateral branch + + f = nn.functional.interpolate( + f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch + f = conv_x + f + + fpn_feature_list.append(self.fpn_out[i](f)) + + fpn_feature_list.reverse() # [P2 - P5] + output_size = fpn_feature_list[0].size()[2:] + fusion_list = [fpn_feature_list[0]] + for i in range(1, len(fpn_feature_list)): + fusion_list.append(nn.functional.interpolate( + fpn_feature_list[i], + output_size, + mode='bilinear', align_corners=False)) + fusion_out = torch.cat(fusion_list, 1) + x 
= self.conv_last(fusion_out) + + if self.inference or self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + if self.use_softmax: + x = nn.functional.softmax(x, dim=1) + return x + + x = nn.functional.log_softmax(x, dim=1) + + return x diff --git a/netdissect/segmodel/object150_info.csv b/netdissect/segmodel/object150_info.csv new file mode 100644 index 0000000000000000000000000000000000000000..8b34d8f3874a38b96894863c5458a7c3c2b0e2e6 --- /dev/null +++ b/netdissect/segmodel/object150_info.csv @@ -0,0 +1,151 @@ +Idx,Ratio,Train,Val,Stuff,Name +1,0.1576,11664,1172,1,wall +2,0.1072,6046,612,1,building;edifice +3,0.0878,8265,796,1,sky +4,0.0621,9336,917,1,floor;flooring +5,0.0480,6678,641,0,tree +6,0.0450,6604,643,1,ceiling +7,0.0398,4023,408,1,road;route +8,0.0231,1906,199,0,bed +9,0.0198,4688,460,0,windowpane;window +10,0.0183,2423,225,1,grass +11,0.0181,2874,294,0,cabinet +12,0.0166,3068,310,1,sidewalk;pavement +13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul +14,0.0151,1804,190,1,earth;ground +15,0.0118,6666,796,0,door;double;door +16,0.0110,4269,411,0,table +17,0.0109,1691,160,1,mountain;mount +18,0.0104,3999,441,0,plant;flora;plant;life +19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall +20,0.0103,3261,318,0,chair +21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar +22,0.0074,709,75,1,water +23,0.0067,3296,315,0,painting;picture +24,0.0065,1191,106,0,sofa;couch;lounge +25,0.0061,1516,162,0,shelf +26,0.0060,667,69,1,house +27,0.0053,651,57,1,sea +28,0.0052,1847,224,0,mirror +29,0.0046,1158,128,1,rug;carpet;carpeting +30,0.0044,480,44,1,field +31,0.0044,1172,98,0,armchair +32,0.0044,1292,184,0,seat +33,0.0033,1386,138,0,fence;fencing +34,0.0031,698,61,0,desk +35,0.0030,781,73,0,rock;stone +36,0.0027,380,43,0,wardrobe;closet;press +37,0.0026,3089,302,0,lamp +38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub +39,0.0024,804,99,0,railing;rail +40,0.0023,1453,153,0,cushion +41,0.0023,411,37,0,base;pedestal;stand +42,0.0022,1440,162,0,box +43,0.0022,800,77,0,column;pillar +44,0.0020,2650,298,0,signboard;sign +45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser +46,0.0019,367,36,0,counter +47,0.0018,311,30,1,sand +48,0.0018,1181,122,0,sink +49,0.0018,287,23,1,skyscraper +50,0.0018,468,38,0,fireplace;hearth;open;fireplace +51,0.0018,402,43,0,refrigerator;icebox +52,0.0018,130,12,1,grandstand;covered;stand +53,0.0018,561,64,1,path +54,0.0017,880,102,0,stairs;steps +55,0.0017,86,12,1,runway +56,0.0017,172,11,0,case;display;case;showcase;vitrine +57,0.0017,198,18,0,pool;table;billiard;table;snooker;table +58,0.0017,930,109,0,pillow +59,0.0015,139,18,0,screen;door;screen +60,0.0015,564,52,1,stairway;staircase +61,0.0015,320,26,1,river +62,0.0015,261,29,1,bridge;span +63,0.0014,275,22,0,bookcase +64,0.0014,335,60,0,blind;screen +65,0.0014,792,75,0,coffee;table;cocktail;table +66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne +67,0.0014,1309,138,0,flower +68,0.0013,1112,113,0,book +69,0.0013,266,27,1,hill +70,0.0013,659,66,0,bench +71,0.0012,331,31,0,countertop +72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove +73,0.0012,369,36,0,palm;palm;tree +74,0.0012,144,9,0,kitchen;island +75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system +76,0.0010,324,33,0,swivel;chair +77,0.0009,304,27,0,boat +78,0.0009,170,20,0,bar +79,0.0009,68,6,0,arcade;machine 
+80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty +81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle +82,0.0008,492,49,0,towel +83,0.0008,2510,269,0,light;light;source +84,0.0008,440,39,0,truck;motortruck +85,0.0008,147,18,1,tower +86,0.0008,583,56,0,chandelier;pendant;pendent +87,0.0007,533,61,0,awning;sunshade;sunblind +88,0.0007,1989,239,0,streetlight;street;lamp +89,0.0007,71,5,0,booth;cubicle;stall;kiosk +90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box +91,0.0007,135,12,0,airplane;aeroplane;plane +92,0.0007,83,5,1,dirt;track +93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes +94,0.0006,1003,104,0,pole +95,0.0006,182,12,1,land;ground;soil +96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail +97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway +98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock +99,0.0006,965,114,0,bottle +100,0.0006,117,13,0,buffet;counter;sideboard +101,0.0006,354,35,0,poster;posting;placard;notice;bill;card +102,0.0006,108,9,1,stage +103,0.0006,557,55,0,van +104,0.0006,52,4,0,ship +105,0.0005,99,5,0,fountain +106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter +107,0.0005,292,31,0,canopy +108,0.0005,77,9,0,washer;automatic;washer;washing;machine +109,0.0005,340,38,0,plaything;toy +110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium +111,0.0005,465,49,0,stool +112,0.0005,50,4,0,barrel;cask +113,0.0005,622,75,0,basket;handbasket +114,0.0005,80,9,1,waterfall;falls +115,0.0005,59,3,0,tent;collapsible;shelter +116,0.0005,531,72,0,bag +117,0.0005,282,30,0,minibike;motorbike +118,0.0005,73,7,0,cradle +119,0.0005,435,44,0,oven +120,0.0005,136,25,0,ball +121,0.0005,116,24,0,food;solid;food +122,0.0004,266,31,0,step;stair +123,0.0004,58,12,0,tank;storage;tank +124,0.0004,418,83,0,trade;name;brand;name;brand;marque +125,0.0004,319,43,0,microwave;microwave;oven +126,0.0004,1193,139,0,pot;flowerpot +127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna +128,0.0004,347,36,0,bicycle;bike;wheel;cycle +129,0.0004,52,5,1,lake +130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine +131,0.0004,108,13,0,screen;silver;screen;projection;screen +132,0.0004,201,30,0,blanket;cover +133,0.0004,285,21,0,sculpture +134,0.0004,268,27,0,hood;exhaust;hood +135,0.0003,1020,108,0,sconce +136,0.0003,1282,122,0,vase +137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight +138,0.0003,453,57,0,tray +139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin +140,0.0003,397,44,0,fan +141,0.0003,92,8,1,pier;wharf;wharfage;dock +142,0.0003,228,18,0,crt;screen +143,0.0003,570,59,0,plate +144,0.0003,217,22,0,monitor;monitoring;device +145,0.0003,206,19,0,bulletin;board;notice;board +146,0.0003,130,14,0,shower +147,0.0003,178,28,0,radiator +148,0.0002,504,57,0,glass;drinking;glass +149,0.0002,775,96,0,clock +150,0.0002,421,56,0,flag diff --git a/netdissect/segmodel/resnet.py b/netdissect/segmodel/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5fdf82fafa3058c5f00074d55fbb1e584d5865 --- /dev/null +++ b/netdissect/segmodel/resnet.py @@ -0,0 +1,235 @@ +import os +import sys +import torch +import torch.nn as nn +import math +try: + from lib.nn import SynchronizedBatchNorm2d +except ImportError: + from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d + +try: + from urllib import urlretrieve +except ImportError: + 
from urllib.request import urlretrieve + + +__all__ = ['ResNet', 'resnet50', 'resnet101'] # resnet101 is coming soon! + + +model_urls = { + 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', + 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = SynchronizedBatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = SynchronizedBatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = SynchronizedBatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = SynchronizedBatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, SynchronizedBatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + SynchronizedBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + +''' +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet34'])) + return model +''' + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet101']), strict=False) + return model + +# def resnet152(pretrained=False, **kwargs): +# """Constructs a ResNet-152 model. 
+# +# Args: +# pretrained (bool): If True, returns a model pre-trained on Places +# """ +# model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnet152'])) +# return model + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) diff --git a/netdissect/segmodel/resnext.py b/netdissect/segmodel/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbb7461a6c8eb126717967cdca5d5ce392aecea --- /dev/null +++ b/netdissect/segmodel/resnext.py @@ -0,0 +1,182 @@ +import os +import sys +import torch +import torch.nn as nn +import math +try: + from lib.nn import SynchronizedBatchNorm2d +except ImportError: + from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + + +__all__ = ['ResNeXt', 'resnext101'] # support resnext 101 + + +model_urls = { + #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', + 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class GroupBottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): + super(GroupBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) + self.bn3 = SynchronizedBatchNorm2d(planes * 2) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNeXt(nn.Module): + + def __init__(self, block, layers, groups=32, num_classes=1000): + self.inplanes = 128 + super(ResNeXt, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = SynchronizedBatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = SynchronizedBatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = SynchronizedBatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) + self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) + self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) + self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, 
groups=groups) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(1024 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, SynchronizedBatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, groups=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + SynchronizedBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, groups, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +''' +def resnext50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext50']), strict=False) + return model +''' + + +def resnext101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext101']), strict=False) + return model + + +# def resnext152(pretrained=False, **kwargs): +# """Constructs a ResNeXt-152 model. +# +# Args: +# pretrained (bool): If True, returns a model pre-trained on Places +# """ +# model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnext152'])) +# return model + + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) diff --git a/netdissect/segviz.py b/netdissect/segviz.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb954317aaf0fd6e31b6216cc7a59f01a5fb0bd --- /dev/null +++ b/netdissect/segviz.py @@ -0,0 +1,283 @@ +import numpy, scipy + +def segment_visualization(seg, size): + result = numpy.zeros((seg.shape[1] * seg.shape[2], 3), dtype=numpy.uint8) + flatseg = seg.reshape(seg.shape[0], seg.shape[1] * seg.shape[2]) + bc = numpy.bincount(flatseg.flatten()) + top = numpy.argsort(-bc) + # In a multilabel segmentation, we can't draw everything. + # Draw the fewest-pixel labels last. (We could pick the opposite order.) 
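+    # Iterate from most- to least-frequent label; rarer labels are drawn later and overwrite earlier ones, so the fewest-pixel labels stay visible on top.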
+ for label in top: + if label == 0: + continue + if bc[label] == 0: + break + bitmap = ((flatseg == label).sum(axis=0) > 0) + result[bitmap] = high_contrast_arr[label % len(high_contrast_arr)] + result = result.reshape((seg.shape[1], seg.shape[2], 3)) + if seg.shape[1:] != size: + result = scipy.misc.imresize(result, size, interp='nearest') + return result + +# A palette that maximizes perceptual contrast between entries. +# https://stackoverflow.com/questions/33295120 +high_contrast = [ + [0, 0, 0], [255, 255, 0], [28, 230, 255], [255, 52, 255], + [255, 74, 70], [0, 137, 65], [0, 111, 166], [163, 0, 89], + [255, 219, 229], [122, 73, 0], [0, 0, 166], [99, 255, 172], + [183, 151, 98], [0, 77, 67], [143, 176, 255], [153, 125, 135], + [90, 0, 7], [128, 150, 147], [254, 255, 230], [27, 68, 0], + [79, 198, 1], [59, 93, 255], [74, 59, 83], [255, 47, 128], + [97, 97, 90], [186, 9, 0], [107, 121, 0], [0, 194, 160], + [255, 170, 146], [255, 144, 201], [185, 3, 170], [209, 97, 0], + [221, 239, 255], [0, 0, 53], [123, 79, 75], [161, 194, 153], + [48, 0, 24], [10, 166, 216], [1, 51, 73], [0, 132, 111], + [55, 33, 1], [255, 181, 0], [194, 255, 237], [160, 121, 191], + [204, 7, 68], [192, 185, 178], [194, 255, 153], [0, 30, 9], + [0, 72, 156], [111, 0, 98], [12, 189, 102], [238, 195, 255], + [69, 109, 117], [183, 123, 104], [122, 135, 161], [120, 141, 102], + [136, 85, 120], [250, 208, 159], [255, 138, 154], [209, 87, 160], + [190, 196, 89], [69, 102, 72], [0, 134, 237], [136, 111, 76], + [52, 54, 45], [180, 168, 189], [0, 166, 170], [69, 44, 44], + [99, 99, 117], [163, 200, 201], [255, 145, 63], [147, 138, 129], + [87, 83, 41], [0, 254, 207], [176, 91, 111], [140, 208, 255], + [59, 151, 0], [4, 247, 87], [200, 161, 161], [30, 110, 0], + [121, 0, 215], [167, 117, 0], [99, 103, 169], [160, 88, 55], + [107, 0, 44], [119, 38, 0], [215, 144, 255], [155, 151, 0], + [84, 158, 121], [255, 246, 159], [32, 22, 37], [114, 65, 143], + [188, 35, 255], [153, 173, 192], [58, 36, 101], [146, 35, 41], + [91, 69, 52], [253, 232, 220], [64, 78, 85], [0, 137, 163], + [203, 126, 152], [164, 232, 4], [50, 78, 114], [106, 58, 76], + [131, 171, 88], [0, 28, 30], [209, 247, 206], [0, 75, 40], + [200, 208, 246], [163, 164, 137], [128, 108, 102], [34, 40, 0], + [191, 86, 80], [232, 48, 0], [102, 121, 109], [218, 0, 124], + [255, 26, 89], [138, 219, 180], [30, 2, 0], [91, 78, 81], + [200, 149, 197], [50, 0, 51], [255, 104, 50], [102, 225, 211], + [207, 205, 172], [208, 172, 148], [126, 211, 121], [1, 44, 88], + [122, 123, 255], [214, 142, 1], [53, 51, 57], [120, 175, 161], + [254, 178, 198], [117, 121, 124], [131, 115, 147], [148, 58, 77], + [181, 244, 255], [210, 220, 213], [149, 86, 189], [106, 113, 74], + [0, 19, 37], [2, 82, 95], [10, 163, 247], [233, 129, 118], + [219, 213, 221], [94, 188, 209], [61, 79, 68], [126, 100, 5], + [2, 104, 78], [150, 43, 117], [141, 133, 70], [150, 149, 197], + [231, 115, 206], [216, 106, 120], [62, 137, 190], [202, 131, 78], + [81, 138, 135], [91, 17, 60], [85, 129, 59], [231, 4, 196], + [0, 0, 95], [169, 115, 153], [75, 129, 96], [89, 115, 138], + [255, 93, 167], [247, 201, 191], [100, 49, 39], [81, 58, 1], + [107, 148, 170], [81, 160, 88], [164, 91, 2], [29, 23, 2], + [226, 0, 39], [231, 171, 99], [76, 96, 1], [156, 105, 102], + [100, 84, 123], [151, 151, 158], [0, 106, 102], [57, 20, 6], + [244, 215, 73], [0, 69, 210], [0, 108, 49], [221, 182, 208], + [124, 101, 113], [159, 178, 164], [0, 216, 145], [21, 160, 138], + [188, 101, 233], [255, 255, 254], [198, 220, 153], [32, 59, 60], + [103, 
17, 144], [107, 58, 100], [245, 225, 255], [255, 160, 242], + [204, 170, 53], [55, 69, 39], [139, 180, 0], [121, 120, 104], + [198, 0, 90], [59, 0, 10], [200, 98, 64], [41, 96, 124], + [64, 35, 52], [125, 90, 68], [204, 184, 124], [184, 129, 131], + [170, 81, 153], [181, 214, 195], [163, 132, 105], [159, 148, 240], + [167, 69, 113], [184, 148, 166], [113, 187, 140], [0, 180, 51], + [120, 158, 201], [109, 128, 186], [149, 63, 0], [94, 255, 3], + [228, 255, 252], [27, 225, 119], [188, 177, 229], [118, 145, 47], + [0, 49, 9], [0, 96, 205], [210, 0, 150], [137, 85, 99], + [41, 32, 29], [91, 50, 19], [167, 111, 66], [137, 65, 46], + [26, 58, 42], [73, 75, 90], [168, 140, 133], [244, 171, 170], + [163, 243, 171], [0, 198, 200], [234, 139, 102], [149, 138, 159], + [189, 201, 210], [159, 160, 100], [190, 71, 0], [101, 129, 136], + [131, 164, 133], [69, 60, 35], [71, 103, 93], [58, 63, 0], + [6, 18, 3], [223, 251, 113], [134, 142, 126], [152, 208, 88], + [108, 143, 125], [215, 191, 194], [60, 62, 110], [216, 61, 102], + [47, 93, 155], [108, 94, 70], [210, 91, 136], [91, 101, 108], + [0, 181, 127], [84, 92, 70], [134, 96, 151], [54, 93, 37], + [37, 47, 153], [0, 204, 255], [103, 78, 96], [252, 0, 156], + [146, 137, 107], [30, 35, 36], [222, 201, 178], [157, 73, 72], + [133, 171, 180], [52, 33, 66], [208, 150, 133], [164, 172, 172], + [0, 255, 255], [174, 156, 134], [116, 42, 51], [14, 114, 197], + [175, 216, 236], [192, 100, 185], [145, 2, 140], [254, 237, 191], + [255, 183, 137], [156, 184, 228], [175, 255, 209], [42, 54, 76], + [79, 74, 67], [100, 112, 149], [52, 187, 255], [128, 119, 129], + [146, 0, 3], [179, 165, 167], [1, 134, 21], [241, 255, 200], + [151, 111, 92], [255, 59, 193], [255, 95, 107], [7, 125, 132], + [245, 109, 147], [87, 113, 218], [78, 30, 42], [131, 0, 85], + [2, 211, 70], [190, 69, 45], [0, 144, 94], [190, 0, 40], + [110, 150, 227], [0, 118, 153], [254, 201, 109], [156, 106, 125], + [63, 161, 184], [137, 61, 227], [121, 180, 214], [127, 212, 217], + [103, 81, 187], [178, 141, 45], [226, 122, 5], [221, 156, 184], + [170, 188, 122], [152, 0, 52], [86, 26, 2], [143, 127, 0], + [99, 80, 0], [205, 125, 174], [138, 94, 45], [255, 179, 225], + [107, 100, 102], [198, 211, 0], [1, 0, 226], [136, 236, 105], + [143, 204, 190], [33, 0, 28], [81, 31, 77], [227, 246, 227], + [255, 142, 177], [107, 79, 41], [163, 127, 70], [106, 89, 80], + [31, 42, 26], [4, 120, 77], [16, 24, 53], [230, 224, 208], + [255, 116, 254], [0, 164, 95], [143, 93, 248], [75, 0, 89], + [65, 47, 35], [216, 147, 158], [219, 157, 114], [96, 65, 67], + [181, 186, 206], [152, 158, 183], [210, 196, 219], [165, 135, 175], + [119, 215, 150], [127, 140, 148], [255, 155, 3], [85, 81, 150], + [49, 221, 174], [116, 182, 113], [128, 38, 71], [42, 55, 63], + [1, 74, 104], [105, 102, 40], [76, 123, 109], [0, 44, 39], + [122, 69, 34], [59, 88, 89], [229, 211, 129], [255, 243, 255], + [103, 159, 160], [38, 19, 0], [44, 87, 66], [145, 49, 175], + [175, 93, 136], [199, 112, 106], [97, 171, 31], [140, 242, 212], + [197, 217, 184], [159, 255, 251], [191, 69, 204], [73, 57, 65], + [134, 59, 96], [185, 0, 118], [0, 49, 119], [197, 130, 210], + [193, 179, 148], [96, 43, 112], [136, 120, 104], [186, 191, 176], + [3, 0, 18], [209, 172, 254], [127, 222, 254], [75, 92, 113], + [163, 160, 151], [230, 109, 83], [99, 123, 93], [146, 190, 165], + [0, 248, 179], [190, 221, 255], [61, 181, 167], [221, 50, 72], + [182, 228, 222], [66, 119, 69], [89, 140, 90], [185, 76, 89], + [129, 129, 213], [148, 136, 139], [254, 214, 189], [83, 109, 49], + [110, 
255, 146], [228, 232, 255], [32, 226, 0], [255, 208, 242], + [76, 131, 161], [189, 115, 34], [145, 92, 78], [140, 71, 135], + [2, 81, 23], [162, 170, 69], [45, 27, 33], [169, 221, 176], + [255, 79, 120], [82, 133, 0], [0, 154, 46], [23, 252, 228], + [113, 85, 90], [82, 93, 130], [0, 25, 90], [150, 120, 116], + [85, 85, 88], [11, 33, 44], [30, 32, 43], [239, 191, 196], + [111, 151, 85], [111, 117, 134], [80, 29, 29], [55, 45, 0], + [116, 29, 22], [94, 179, 147], [181, 180, 0], [221, 74, 56], + [54, 61, 255], [173, 101, 82], [102, 53, 175], [131, 107, 186], + [152, 170, 127], [70, 72, 54], [50, 44, 62], [124, 185, 186], + [91, 105, 101], [112, 125, 61], [122, 0, 29], [110, 70, 54], + [68, 58, 56], [174, 129, 255], [72, 144, 121], [137, 115, 52], + [0, 144, 135], [218, 113, 60], [54, 22, 24], [255, 111, 1], + [0, 102, 121], [55, 14, 119], [75, 58, 131], [201, 226, 230], + [196, 65, 112], [255, 69, 38], [115, 190, 84], [196, 223, 114], + [173, 255, 96], [0, 68, 125], [220, 206, 201], [189, 148, 121], + [101, 110, 91], [236, 82, 0], [255, 110, 194], [122, 97, 126], + [221, 174, 162], [119, 131, 127], [165, 51, 39], [96, 142, 255], + [181, 153, 215], [165, 1, 73], [78, 0, 37], [201, 177, 169], + [3, 145, 154], [27, 42, 37], [229, 0, 241], [152, 46, 11], + [182, 113, 128], [224, 88, 89], [0, 96, 57], [87, 143, 155], + [48, 82, 48], [206, 147, 76], [179, 194, 190], [192, 186, 192], + [181, 6, 211], [23, 12, 16], [76, 83, 79], [34, 68, 81], + [62, 65, 65], [120, 114, 109], [182, 96, 43], [32, 4, 65], + [221, 181, 136], [73, 114, 0], [197, 170, 182], [3, 60, 97], + [113, 178, 245], [169, 224, 136], [73, 121, 176], [162, 195, 223], + [120, 65, 73], [45, 43, 23], [62, 14, 47], [87, 52, 76], + [0, 145, 190], [228, 81, 209], [75, 75, 106], [92, 1, 26], + [124, 128, 96], [255, 148, 145], [76, 50, 93], [0, 92, 139], + [229, 253, 164], [104, 209, 182], [3, 38, 65], [20, 0, 35], + [134, 131, 169], [207, 255, 0], [167, 44, 62], [52, 71, 90], + [177, 187, 154], [180, 160, 79], [141, 145, 142], [161, 104, 166], + [129, 61, 58], [66, 82, 24], [218, 131, 134], [119, 97, 51], + [86, 57, 48], [132, 152, 174], [144, 193, 211], [181, 102, 107], + [155, 88, 94], [133, 100, 101], [173, 124, 144], [226, 188, 0], + [227, 170, 224], [178, 194, 254], [253, 0, 57], [0, 155, 117], + [255, 244, 109], [232, 126, 172], [223, 227, 230], [132, 133, 144], + [170, 146, 151], [131, 161, 147], [87, 121, 119], [62, 113, 88], + [198, 66, 137], [234, 0, 114], [196, 168, 203], [85, 200, 153], + [231, 143, 207], [0, 69, 71], [246, 226, 227], [150, 103, 22], + [55, 143, 219], [67, 94, 106], [218, 0, 4], [27, 0, 15], + [91, 156, 143], [110, 43, 82], [1, 17, 21], [227, 232, 196], + [174, 59, 133], [234, 28, 169], [255, 158, 107], [69, 125, 139], + [146, 103, 139], [0, 205, 187], [156, 204, 4], [0, 46, 56], + [150, 197, 127], [207, 246, 180], [73, 40, 24], [118, 110, 82], + [32, 55, 14], [227, 209, 159], [46, 60, 48], [178, 234, 206], + [243, 189, 164], [162, 78, 61], [151, 111, 217], [140, 159, 168], + [124, 43, 115], [78, 95, 55], [93, 84, 98], [144, 149, 111], + [106, 167, 118], [219, 203, 246], [218, 113, 255], [152, 124, 149], + [82, 50, 60], [187, 60, 66], [88, 77, 57], [79, 193, 95], + [162, 185, 193], [121, 219, 33], [29, 89, 88], [189, 116, 78], + [22, 11, 0], [32, 34, 26], [107, 130, 149], [0, 224, 228], + [16, 36, 1], [27, 120, 42], [218, 169, 181], [176, 65, 93], + [133, 146, 83], [151, 160, 148], [6, 227, 196], [71, 104, 140], + [124, 103, 85], [7, 92, 0], [117, 96, 213], [125, 159, 0], + [195, 109, 150], [77, 145, 62], [95, 66, 
118], [252, 228, 200], + [48, 48, 82], [79, 56, 27], [229, 165, 50], [112, 102, 144], + [170, 154, 146], [35, 115, 99], [115, 1, 62], [255, 144, 121], + [167, 154, 116], [2, 155, 219], [255, 1, 105], [199, 210, 231], + [202, 136, 105], [128, 255, 205], [187, 31, 105], [144, 176, 171], + [125, 116, 169], [252, 199, 219], [153, 55, 91], [0, 171, 77], + [171, 174, 209], [190, 157, 145], [230, 229, 167], [51, 44, 34], + [221, 88, 123], [245, 255, 247], [93, 48, 51], [109, 56, 0], + [255, 0, 32], [181, 123, 179], [215, 255, 230], [197, 53, 169], + [38, 0, 9], [106, 135, 129], [168, 171, 180], [212, 82, 98], + [121, 75, 97], [70, 33, 178], [141, 164, 219], [199, 200, 144], + [111, 233, 173], [162, 67, 167], [178, 176, 129], [24, 27, 0], + [40, 97, 84], [76, 164, 59], [106, 149, 115], [168, 68, 29], + [92, 114, 123], [115, 134, 113], [208, 207, 203], [137, 123, 119], + [31, 63, 34], [65, 69, 167], [218, 152, 148], [161, 117, 122], + [99, 36, 60], [173, 170, 255], [0, 205, 226], [221, 188, 98], + [105, 142, 177], [32, 132, 98], [0, 183, 224], [97, 74, 68], + [155, 187, 87], [122, 92, 84], [133, 122, 80], [118, 107, 126], + [1, 72, 51], [255, 131, 71], [122, 142, 186], [39, 71, 64], + [148, 100, 68], [235, 216, 230], [100, 98, 65], [55, 57, 23], + [106, 212, 80], [129, 129, 123], [212, 153, 227], [151, 148, 64], + [1, 26, 18], [82, 101, 84], [181, 136, 92], [164, 153, 165], + [3, 173, 137], [179, 0, 139], [227, 196, 181], [150, 83, 31], + [134, 113, 117], [116, 86, 158], [97, 125, 159], [231, 4, 82], + [6, 126, 175], [166, 151, 182], [183, 135, 168], [156, 255, 147], + [49, 29, 25], [58, 148, 89], [110, 116, 110], [176, 197, 174], + [132, 237, 247], [237, 52, 136], [117, 76, 120], [56, 70, 68], + [199, 132, 123], [0, 182, 197], [127, 166, 112], [193, 175, 158], + [42, 127, 255], [114, 165, 140], [255, 192, 127], [157, 235, 221], + [217, 124, 142], [126, 124, 147], [98, 230, 116], [181, 99, 158], + [255, 168, 97], [194, 165, 128], [141, 156, 131], [183, 5, 70], + [55, 43, 46], [0, 152, 255], [152, 89, 117], [32, 32, 76], + [255, 108, 96], [68, 80, 131], [133, 2, 170], [114, 54, 31], + [150, 118, 163], [72, 68, 73], [206, 214, 194], [59, 22, 74], + [204, 167, 99], [44, 127, 119], [2, 34, 123], [163, 126, 111], + [205, 230, 220], [205, 255, 251], [190, 129, 26], [247, 113, 131], + [237, 230, 226], [205, 198, 180], [255, 224, 158], [58, 114, 113], + [255, 123, 89], [78, 78, 1], [74, 198, 132], [139, 200, 145], + [188, 138, 150], [207, 99, 83], [220, 222, 92], [94, 170, 221], + [246, 160, 173], [226, 105, 170], [163, 218, 228], [67, 110, 131], + [0, 46, 23], [236, 251, 255], [161, 194, 182], [80, 0, 63], + [113, 105, 91], [103, 196, 187], [83, 110, 255], [93, 90, 72], + [137, 0, 57], [150, 147, 129], [55, 21, 33], [94, 70, 101], + [170, 98, 195], [141, 111, 129], [44, 97, 53], [65, 6, 1], + [86, 70, 32], [230, 144, 52], [109, 166, 189], [229, 142, 86], + [227, 166, 139], [72, 177, 118], [210, 125, 103], [181, 178, 104], + [127, 132, 39], [255, 132, 230], [67, 87, 64], [234, 228, 8], + [244, 245, 255], [50, 88, 0], [75, 107, 165], [173, 206, 255], + [155, 138, 204], [136, 81, 56], [88, 117, 193], [126, 115, 17], + [254, 165, 202], [159, 139, 91], [165, 91, 84], [137, 0, 106], + [175, 117, 111], [42, 32, 0], [116, 153, 161], [255, 181, 80], + [0, 1, 30], [209, 81, 28], [104, 129, 81], [188, 144, 138], + [120, 200, 235], [133, 2, 255], [72, 61, 48], [196, 34, 33], + [94, 167, 255], [120, 87, 21], [12, 234, 145], [255, 250, 237], + [179, 175, 157], [62, 61, 82], [90, 155, 194], [156, 47, 144], + [141, 87, 
0], [173, 215, 156], [0, 118, 139], [51, 125, 0], + [197, 151, 0], [49, 86, 220], [148, 69, 117], [236, 255, 220], + [210, 76, 178], [151, 112, 60], [76, 37, 127], [158, 3, 102], + [136, 255, 236], [181, 100, 129], [57, 109, 43], [86, 115, 95], + [152, 131, 118], [155, 177, 149], [169, 121, 92], [228, 197, 211], + [159, 79, 103], [30, 43, 57], [102, 67, 39], [175, 206, 120], + [50, 46, 223], [134, 180, 135], [194, 48, 0], [171, 232, 107], + [150, 101, 109], [37, 14, 53], [166, 0, 25], [0, 128, 207], + [202, 239, 255], [50, 63, 97], [164, 73, 220], [106, 157, 59], + [255, 90, 228], [99, 106, 1], [209, 108, 218], [115, 96, 96], + [255, 186, 173], [211, 105, 180], [255, 222, 214], [108, 109, 116], + [146, 125, 94], [132, 93, 112], [91, 98, 193], [47, 74, 54], + [228, 95, 53], [255, 59, 83], [172, 132, 221], [118, 41, 136], + [112, 236, 152], [64, 133, 67], [44, 53, 51], [46, 24, 45], + [50, 57, 37], [25, 24, 27], [47, 46, 44], [2, 60, 50], + [155, 158, 226], [88, 175, 173], [92, 66, 77], [122, 197, 166], + [104, 93, 117], [185, 188, 189], [131, 67, 87], [26, 123, 66], + [46, 87, 170], [229, 81, 153], [49, 110, 71], [205, 0, 197], + [106, 0, 77], [127, 187, 236], [243, 86, 145], [215, 197, 74], + [98, 172, 183], [203, 161, 188], [162, 138, 154], [108, 63, 59], + [255, 228, 125], [220, 186, 227], [95, 129, 109], [58, 64, 74], + [125, 191, 50], [230, 236, 220], [133, 44, 25], [40, 83, 102], + [184, 203, 156], [14, 13, 0], [75, 93, 86], [107, 84, 63], + [226, 113, 114], [5, 104, 236], [46, 181, 0], [210, 22, 86], + [239, 175, 255], [104, 32, 33], [45, 32, 17], [218, 76, 255], + [112, 150, 142], [255, 123, 125], [74, 25, 48], [232, 194, 130], + [231, 219, 188], [166, 132, 134], [31, 38, 60], [54, 87, 78], + [82, 206, 121], [173, 170, 169], [138, 159, 69], [101, 66, 210], + [0, 251, 140], [93, 105, 123], [204, 210, 127], [148, 165, 161], + [121, 2, 41], [227, 131, 230], [126, 164, 193], [78, 68, 82], + [75, 44, 0], [98, 11, 112], [49, 76, 30], [135, 74, 166], + [227, 0, 145], [102, 70, 10], [235, 154, 139], [234, 195, 163], + [152, 234, 179], [171, 145, 128], [184, 85, 47], [26, 43, 47], + [148, 221, 197], [157, 140, 118], [156, 131, 51], [148, 169, 201], + [57, 41, 53], [140, 103, 94], [204, 233, 58], [145, 113, 0], + [1, 64, 11], [68, 152, 150], [28, 163, 112], [224, 141, 167], + [139, 74, 78], [102, 119, 118], [70, 146, 173], [103, 189, 168], + [105, 37, 92], [211, 191, 255], [74, 81, 50], [126, 146, 133], + [119, 115, 60], [231, 160, 204], [81, 162, 136], [44, 101, 106], + [77, 92, 94], [201, 64, 58], [221, 215, 243], [0, 88, 68], + [180, 162, 0], [72, 143, 105], [133, 129, 130], [212, 233, 185], + [61, 115, 151], [202, 232, 206], [214, 0, 52], [170, 103, 70], + [158, 85, 133], [186, 98, 0] +] + +high_contrast_arr = numpy.array(high_contrast, dtype=numpy.uint8) diff --git a/netdissect/server.py b/netdissect/server.py new file mode 100644 index 0000000000000000000000000000000000000000..d8422a2bad5ac2a09d4582a98da4f962dac1a911 --- /dev/null +++ b/netdissect/server.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python + +import argparse, connexion, os, sys, yaml, json, socket +from netdissect.easydict import EasyDict +from flask import send_from_directory, redirect +from flask_cors import CORS + + +from netdissect.serverstate import DissectionProject + +__author__ = 'Hendrik Strobelt, David Bau' + +CONFIG_FILE_NAME = 'dissect.json' +projects = {} + +app = connexion.App(__name__, debug=False) + + +def get_all_projects(): + res = [] + for key, project in projects.items(): + # print key + res.append({ + 
'project': key, + 'info': { + 'layers': [layer['layer'] for layer in project.get_layers()] + } + }) + return sorted(res, key=lambda x: x['project']) + +def get_layers(project): + return { + 'request': {'project': project}, + 'res': projects[project].get_layers() + } + +def get_units(project, layer): + return { + 'request': {'project': project, 'layer': layer}, + 'res': projects[project].get_units(layer) + } + +def get_rankings(project, layer): + return { + 'request': {'project': project, 'layer': layer}, + 'res': projects[project].get_rankings(layer) + } + +def get_levels(project, layer, quantiles): + return { + 'request': {'project': project, 'layer': layer, 'quantiles': quantiles}, + 'res': projects[project].get_levels(layer, quantiles) + } + +def get_channels(project, layer): + answer = dict(channels=projects[project].get_channels(layer)) + return { + 'request': {'project': project, 'layer': layer}, + 'res': answer + } + +def post_generate(gen_req): + project = gen_req['project'] + zs = gen_req.get('zs', None) + ids = gen_req.get('ids', None) + return_urls = gen_req.get('return_urls', False) + assert (zs is None) != (ids is None) # one or the other, not both + ablations = gen_req.get('ablations', []) + interventions = gen_req.get('interventions', None) + # no z avilable if ablations + generated = projects[project].generate_images(zs, ids, interventions, + return_urls=return_urls) + return { + 'request': gen_req, + 'res': generated + } + +def post_features(feat_req): + project = feat_req['project'] + ids = feat_req['ids'] + masks = feat_req.get('masks', None) + layers = feat_req.get('layers', None) + interventions = feat_req.get('interventions', None) + features = projects[project].get_features( + ids, masks, layers, interventions) + return { + 'request': feat_req, + 'res': features + } + +def post_featuremaps(feat_req): + project = feat_req['project'] + ids = feat_req['ids'] + layers = feat_req.get('layers', None) + interventions = feat_req.get('interventions', None) + featuremaps = projects[project].get_featuremaps( + ids, layers, interventions) + return { + 'request': feat_req, + 'res': featuremaps + } + +@app.route('/client/') +def send_static(path): + """ serves all files from ./client/ to ``/client/`` + + :param path: path from api call + """ + return send_from_directory(args.client, path) + +@app.route('/data/') +def send_data(path): + """ serves all files from the data dir to ``/dissect/`` + + :param path: path from api call + """ + print('Got the data route for', path) + return send_from_directory(args.data, path) + + +@app.route('/') +def redirect_home(): + return redirect('/client/index.html', code=302) + + +def load_projects(directory): + """ + searches for CONFIG_FILE_NAME in all subdirectories of directory + and creates data handlers for all of them + + :param directory: scan directory + :return: null + """ + project_dirs = [] + # Don't search more than 2 dirs deep. + search_depth = 2 + directory.count(os.path.sep) + for root, dirs, files in os.walk(directory): + if CONFIG_FILE_NAME in files: + project_dirs.append(root) + # Don't get subprojects under a project dir. 
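+            # Pruning dirs in place keeps os.walk from descending into this project's subdirectories.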
+ del dirs[:] + elif root.count(os.path.sep) >= search_depth: + del dirs[:] + for p_dir in project_dirs: + print('Loading %s' % os.path.join(p_dir, CONFIG_FILE_NAME)) + with open(os.path.join(p_dir, CONFIG_FILE_NAME), 'r') as jf: + config = EasyDict(json.load(jf)) + dh_id = os.path.split(p_dir)[1] + projects[dh_id] = DissectionProject( + config=config, + project_dir=p_dir, + path_url='data/' + os.path.relpath(p_dir, directory), + public_host=args.public_host) + +app.add_api('server.yaml') + +# add CORS support +CORS(app.app, headers='Content-Type') + +parser = argparse.ArgumentParser() +parser.add_argument("--nodebug", default=False) +parser.add_argument("--address", default="127.0.0.1") # 0.0.0.0 for nonlocal use +parser.add_argument("--port", default="5001") +parser.add_argument("--public_host", default=None) +parser.add_argument("--nocache", default=False) +parser.add_argument("--data", type=str, default='dissect') +parser.add_argument("--client", type=str, default='client_dist') + +if __name__ == '__main__': + args = parser.parse_args() + for d in [args.data, args.client]: + if not os.path.isdir(d): + print('No directory %s' % d) + sys.exit(1) + args.data = os.path.abspath(args.data) + args.client = os.path.abspath(args.client) + if args.public_host is None: + args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port)) + app.run(port=int(args.port), debug=not args.nodebug, host=args.address, + use_reloader=False) +else: + args, _ = parser.parse_known_args() + if args.public_host is None: + args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port)) + load_projects(args.data) diff --git a/netdissect/server.yaml b/netdissect/server.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e67b9bbcb24397a21623009b4b6bf0e6d4c9193 --- /dev/null +++ b/netdissect/server.yaml @@ -0,0 +1,300 @@ +swagger: '2.0' +info: + title: Ganter API + version: "0.1" +consumes: + - application/json +produces: + - application/json + +basePath: /api + +paths: + /all_projects: + get: + tags: + - all + summary: information about all projects and sources available + operationId: netdissect.server.get_all_projects + responses: + 200: + description: return list of projects + schema: + type: array + items: + type: object + + /layers: + get: + operationId: netdissect.server.get_layers + tags: + - all + summary: returns information about all layers + parameters: + - $ref: '#/parameters/project' + responses: + 200: + description: Return requested data + schema: + type: object + + /units: + get: + operationId: netdissect.server.get_units + tags: + - all + summary: returns unit information for one layer + parameters: + + - $ref: '#/parameters/project' + - $ref: '#/parameters/layer' + + responses: + 200: + description: Return requested data + schema: + type: object + + /rankings: + get: + operationId: netdissect.server.get_rankings + tags: + - all + summary: returns ranking information for one layer + parameters: + + - $ref: '#/parameters/project' + - $ref: '#/parameters/layer' + + responses: + 200: + description: Return requested data + schema: + type: object + + /levels: + get: + operationId: netdissect.server.get_levels + tags: + - all + summary: returns feature levels for one layer + parameters: + + - $ref: '#/parameters/project' + - $ref: '#/parameters/layer' + - $ref: '#/parameters/quantiles' + + responses: + 200: + description: Return requested data + schema: + type: object + + /features: + post: + summary: calculates max feature values within a set of image locations + operationId: 
netdissect.server.post_features + tags: + - all + parameters: + - in: body + name: feat_req + description: RequestObject + schema: + $ref: "#/definitions/FeatureRequest" + responses: + 200: + description: returns feature vector for each layer + + /featuremaps: + post: + summary: calculates max feature values within a set of image locations + operationId: netdissect.server.post_featuremaps + tags: + - all + parameters: + - in: body + name: feat_req + description: RequestObject + schema: + $ref: "#/definitions/FeatureMapRequest" + responses: + 200: + description: returns feature vector for each layer + + /channels: + get: + operationId: netdissect.server.get_channels + tags: + - all + summary: returns channel information + parameters: + + - $ref: '#/parameters/project' + - $ref: '#/parameters/layer' + + responses: + 200: + description: Return requested data + schema: + type: object + + /generate: + post: + summary: generates images for given zs constrained by ablation + operationId: netdissect.server.post_generate + tags: + - all + parameters: + - in: body + name: gen_req + description: RequestObject + schema: + $ref: "#/definitions/GenerateRequest" + responses: + 200: + description: aaa + + +parameters: + project: + name: project + description: project ID + in: query + required: true + type: string + + layer: + name: layer + description: layer ID + in: query + type: string + default: "1" + + quantiles: + name: quantiles + in: query + type: array + items: + type: number + format: float + +definitions: + + GenerateRequest: + type: object + required: + - project + properties: + project: + type: string + zs: + type: array + items: + type: array + items: + type: number + format: float + ids: + type: array + items: + type: integer + return_urls: + type: integer + interventions: + type: array + items: + - $ref: '#/definitions/Intervention' + + FeatureRequest: + type: object + required: + - project + properties: + project: + type: string + example: 'churchoutdoor' + layers: + type: array + items: + type: string + example: [ 'layer5' ] + ids: + type: array + items: + type: integer + masks: + type: array + items: + - $ref: '#/definitions/Mask' + interventions: + type: array + items: + - $ref: '#/definitions/Intervention' + + FeatureMapRequest: + type: object + required: + - project + properties: + project: + type: string + example: 'churchoutdoor' + layers: + type: array + items: + type: string + example: [ 'layer5' ] + ids: + type: array + items: + type: integer + interventions: + type: array + items: + - $ref: '#/definitions/Intervention' + + Intervention: + type: object + properties: + maskalpha: + $ref: '#/definitions/Mask' + maskvalue: + $ref: '#/definitions/Mask' + ablations: + type: array + items: + - $ref: '#/definitions/Ablation' + + Ablation: + type: object + properties: + unit: + type: integer + alpha: + type: number + format: float + value: + type: number + format: float + layer: + type: string + + Mask: + type: object + description: 2d bitmap mask + properties: + shape: + type: array + items: + type: integer + example: [ 128, 128 ] + bitbounds: + type: array + items: + type: integer + example: [ 12, 42, 16, 46 ] + bitstring: + type: string + example: '0110111111110011' + diff --git a/netdissect/serverstate.py b/netdissect/serverstate.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ddc790c3dfc881f8aa4322d10d90e4e4fc09f0 --- /dev/null +++ b/netdissect/serverstate.py @@ -0,0 +1,526 @@ +import os, torch, numpy, base64, json, re, threading, random +from 
torch.utils.data import TensorDataset, DataLoader +from collections import defaultdict +from netdissect.easydict import EasyDict +from netdissect.modelconfig import create_instrumented_model +from netdissect.runningstats import RunningQuantile +from netdissect.dissection import safe_dir_name +from netdissect.zdataset import z_sample_for_model +from PIL import Image +from io import BytesIO + +class DissectionProject: + ''' + DissectionProject understand how to drive a GanTester within a + dissection project directory structure: it caches data in files, + creates image files, and translates data between plain python data + types and the pytorch-specific tensors required by GanTester. + ''' + def __init__(self, config, project_dir, path_url, public_host): + print('config done', project_dir) + self.use_cuda = torch.cuda.is_available() + self.dissect = config + self.project_dir = project_dir + self.path_url = path_url + self.public_host = public_host + self.cachedir = os.path.join(self.project_dir, 'cache') + self.tester = GanTester( + config.settings, dissectdir=project_dir, + device=torch.device('cuda') if self.use_cuda + else torch.device('cpu')) + self.stdz = [] + + def get_zs(self, size): + if size <= len(self.stdz): + return self.stdz[:size].tolist() + z_tensor = self.tester.standard_z_sample(size) + numpy_z = z_tensor.cpu().numpy() + self.stdz = numpy_z + return self.stdz.tolist() + + def get_z(self, id): + if id < len(self.stdz): + return self.stdz[id] + return self.get_zs((id + 1) * 2)[id] + + def get_zs_for_ids(self, ids): + max_id = max(ids) + if max_id >= len(self.stdz): + self.get_z(max_id) + return self.stdz[ids] + + def get_layers(self): + result = [] + layer_shapes = self.tester.layer_shapes() + for layer in self.tester.layers: + shape = layer_shapes[layer] + result.append(dict( + layer=layer, + channels=shape[1], + shape=[shape[2], shape[3]])) + return result + + def get_units(self, layer): + try: + dlayer = [dl for dl in self.dissect['layers'] + if dl['layer'] == layer][0] + except: + return None + + dunits = dlayer['units'] + result = [dict(unit=unit_num, + img='/%s/%s/s-image/%d-top.jpg' % + (self.path_url, layer, unit_num), + label=unit['iou_label']) + for unit_num, unit in enumerate(dunits)] + return result + + def get_rankings(self, layer): + try: + dlayer = [dl for dl in self.dissect['layers'] + if dl['layer'] == layer][0] + except: + return None + result = [dict(name=ranking['name'], + metric=ranking.get('metric', None), + scores=ranking['score']) + for ranking in dlayer['rankings']] + return result + + def get_levels(self, layer, quantiles): + levels = self.tester.levels( + layer, torch.from_numpy(numpy.array(quantiles))) + return levels.cpu().numpy().tolist() + + def generate_images(self, zs, ids, interventions, return_urls=False): + if ids is not None: + assert zs is None + zs = self.get_zs_for_ids(ids) + if not interventions: + # Do file caching when ids are given (and no ablations). + imgdir = os.path.join(self.cachedir, 'img', 'id') + os.makedirs(imgdir, exist_ok=True) + exist = set(os.listdir(imgdir)) + unfinished = [('%d.jpg' % id) not in exist for id in ids] + needed_z_tensor = torch.tensor(zs[unfinished]).float().to( + self.tester.device) + needed_ids = numpy.array(ids)[unfinished] + # Generate image files for just the needed images. 
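+            # zs[unfinished] keeps only the latent rows whose cached jpg does not exist yet.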
+ if len(needed_z_tensor): + imgs = self.tester.generate_images(needed_z_tensor + ).cpu().numpy() + for i, img in zip(needed_ids, imgs): + Image.fromarray(img.transpose(1, 2, 0)).save( + os.path.join(imgdir, '%d.jpg' % i), 'jpeg', + quality=99, optimize=True, progressive=True) + # Assemble a response. + imgurls = ['/%s/cache/img/id/%d.jpg' + % (self.path_url, i) for i in ids] + return [dict(id=i, d=d) for i, d in zip(ids, imgurls)] + # No file caching when ids are not given (or ablations are applied) + z_tensor = torch.tensor(zs).float().to(self.tester.device) + imgs = self.tester.generate_images(z_tensor, + intervention=decode_intervention_array(interventions, + self.tester.layer_shapes()), + ).cpu().numpy() + numpy_z = z_tensor.cpu().numpy() + if return_urls: + randdir = '%03d' % random.randrange(1000) + imgdir = os.path.join(self.cachedir, 'img', 'uniq', randdir) + os.makedirs(imgdir, exist_ok=True) + startind = random.randrange(100000) + imgurls = [] + for i, img in enumerate(imgs): + filename = '%d.jpg' % (i + startind) + Image.fromarray(img.transpose(1, 2, 0)).save( + os.path.join(imgdir, filename), 'jpeg', + quality=99, optimize=True, progressive=True) + image_url_path = ('/%s/cache/img/uniq/%s/%s' + % (self.path_url, randdir, filename)) + imgurls.append(image_url_path) + tweet_filename = 'tweet-%d.html' % (i + startind) + tweet_url_path = ('/%s/cache/img/uniq/%s/%s' + % (self.path_url, randdir, tweet_filename)) + with open(os.path.join(imgdir, tweet_filename), 'w') as f: + f.write(twitter_card(image_url_path, tweet_url_path, + self.public_host)) + return [dict(d=d) for d in imgurls] + imgurls = [img2base64(img.transpose(1, 2, 0)) for img in imgs] + return [dict(d=d) for d in imgurls] + + def get_features(self, ids, masks, layers, interventions): + zs = self.get_zs_for_ids(ids) + z_tensor = torch.tensor(zs).float().to(self.tester.device) + t_masks = torch.stack( + [torch.from_numpy(mask_to_numpy(mask)) for mask in masks] + )[:,None,:,:].to(self.tester.device) + t_features = self.tester.feature_stats(z_tensor, t_masks, + decode_intervention_array(interventions, + self.tester.layer_shapes()), layers) + # Convert torch arrays to plain python lists before returning. + return { layer: { key: value.cpu().numpy().tolist() + for key, value in feature.items() } + for layer, feature in t_features.items() } + + def get_featuremaps(self, ids, layers, interventions): + zs = self.get_zs_for_ids(ids) + z_tensor = torch.tensor(zs).float().to(self.tester.device) + # Quantilized features are returned. + q_features = self.tester.feature_maps(z_tensor, + decode_intervention_array(interventions, + self.tester.layer_shapes()), layers) + # Scale them 0-255 and return them. + # TODO: turn them into pngs for returning. + return { layer: [ + value.clamp(0, 1).mul(255).byte().cpu().numpy().tolist() + for value in valuelist ] + for layer, valuelist in q_features.items() + if (not layers) or (layer in layers) } + + def get_recipes(self): + recipedir = os.path.join(self.project_dir, 'recipe') + if not os.path.isdir(recipedir): + return [] + result = [] + for filename in os.listdir(recipedir): + with open(os.path.join(recipedir, filename)) as f: + result.append(json.load(f)) + return result + + + + +class GanTester: + ''' + GanTester holds on to a specific model to test. + + (1) loads and instantiates the GAN; + (2) instruments it at every layer so that units can be ablated + (3) precomputes z dimensionality, and output image dimensions. 
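+    (4) exposes generate_images, feature_stats, and feature_maps, each of which can apply optional unit interventions.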
+ ''' + def __init__(self, args, dissectdir=None, device=None): + self.cachedir = os.path.join(dissectdir, 'cache') + self.device = device if device is not None else torch.device('cpu') + self.dissectdir = dissectdir + self.modellock = threading.Lock() + + # Load the generator from the pth file. + args_copy = EasyDict(args) + args_copy.edit = True + model = create_instrumented_model(args_copy) + model.eval() + self.model = model + + # Get the set of layers of interest. + # Default: all shallow children except last. + self.layers = sorted(model.retained_features().keys()) + + # Move it to CUDA if wanted. + model.to(device) + + self.quantiles = { + layer: load_quantile_if_present(os.path.join(self.dissectdir, + safe_dir_name(layer)), 'quantiles.npz', + device=torch.device('cpu')) + for layer in self.layers } + + def layer_shapes(self): + return self.model.feature_shape + + def standard_z_sample(self, size=100, seed=1, device=None): + ''' + Generate a standard set of random Z as a (size, z_dimension) tensor. + With the same random seed, it always returns the same z (e.g., + the first one is always the same regardless of the size.) + ''' + result = z_sample_for_model(self.model, size) + if device is not None: + result = result.to(device) + return result + + def reset_intervention(self): + self.model.remove_edits() + + def apply_intervention(self, intervention): + ''' + Applies an ablation recipe of the form [(layer, unit, alpha)...]. + ''' + self.reset_intervention() + if not intervention: + return + for layer, (a, v) in intervention.items(): + self.model.edit_layer(layer, ablation=a, replacement=v) + + def generate_images(self, z_batch, intervention=None): + ''' + Makes some images. + ''' + with torch.no_grad(), self.modellock: + batch_size = 10 + self.apply_intervention(intervention) + test_loader = DataLoader(TensorDataset(z_batch[:,:,None,None]), + batch_size=batch_size, + pin_memory=('cuda' == self.device.type + and z_batch.device.type == 'cpu')) + result_img = torch.zeros( + *((len(z_batch), 3) + self.model.output_shape[2:]), + dtype=torch.uint8, device=self.device) + for batch_num, [batch_z,] in enumerate(test_loader): + batch_z = batch_z.to(self.device) + out = self.model(batch_z) + result_img[batch_num*batch_size: + batch_num*batch_size+len(batch_z)] = ( + (((out + 1) / 2) * 255).clamp(0, 255).byte()) + return result_img + + def get_layers(self): + return self.layers + + def feature_stats(self, z_batch, + masks=None, intervention=None, layers=None): + feature_stat = defaultdict(dict) + with torch.no_grad(), self.modellock: + batch_size = 10 + self.apply_intervention(intervention) + if masks is None: + masks = torch.ones(z_batch.size(0), 1, 1, 1, + device=z_batch.device, dtype=z_batch.dtype) + else: + assert masks.shape[0] == z_batch.shape[0] + assert masks.shape[1] == 1 + test_loader = DataLoader( + TensorDataset(z_batch[:,:,None,None], masks), + batch_size=batch_size, + pin_memory=('cuda' == self.device.type + and z_batch.device.type == 'cpu')) + processed = 0 + for batch_num, [batch_z, batch_m] in enumerate(test_loader): + batch_z, batch_m = [ + d.to(self.device) for d in [batch_z, batch_m]] + # Run model but disregard output + self.model(batch_z) + processing = batch_z.shape[0] + for layer, feature in self.model.retained_features().items(): + if layers is not None: + if layer not in layers: + continue + # Compute max features touching mask + resized_max = torch.nn.functional.adaptive_max_pool2d( + batch_m, + (feature.shape[2], feature.shape[3])) + max_feature = (feature * 
resized_max).view( + feature.shape[0], feature.shape[1], -1 + ).max(2)[0].max(0)[0] + if 'max' not in feature_stat[layer]: + feature_stat[layer]['max'] = max_feature + else: + torch.max(feature_stat[layer]['max'], max_feature, + out=feature_stat[layer]['max']) + # Compute mean features weighted by overlap with mask + resized_mean = torch.nn.functional.adaptive_avg_pool2d( + batch_m, + (feature.shape[2], feature.shape[3])) + mean_feature = (feature * resized_mean).view( + feature.shape[0], feature.shape[1], -1 + ).sum(2).sum(0) / (resized_mean.sum() + 1e-15) + if 'mean' not in feature_stat[layer]: + feature_stat[layer]['mean'] = mean_feature + else: + feature_stat[layer]['mean'] = ( + processed * feature_mean[layer]['mean'] + + processing * mean_feature) / ( + processed + processing) + processed += processing + # After summaries are done, also compute quantile stats + for layer, stats in feature_stat.items(): + if self.quantiles.get(layer, None) is not None: + for statname in ['max', 'mean']: + stats['%s_quantile' % statname] = ( + self.quantiles[layer].normalize(stats[statname])) + return feature_stat + + def levels(self, layer, quantiles): + return self.quantiles[layer].quantiles(quantiles) + + def feature_maps(self, z_batch, intervention=None, layers=None, + quantiles=True): + feature_map = defaultdict(list) + with torch.no_grad(), self.modellock: + batch_size = 10 + self.apply_intervention(intervention) + test_loader = DataLoader( + TensorDataset(z_batch[:,:,None,None]), + batch_size=batch_size, + pin_memory=('cuda' == self.device.type + and z_batch.device.type == 'cpu')) + processed = 0 + for batch_num, [batch_z] in enumerate(test_loader): + batch_z = batch_z.to(self.device) + # Run model but disregard output + self.model(batch_z) + processing = batch_z.shape[0] + for layer, feature in self.model.retained_features().items(): + for single_featuremap in feature: + if quantiles: + feature_map[layer].append(self.quantiles[layer] + .normalize(single_featuremap)) + else: + feature_map[layer].append(single_featuremap) + return feature_map + +def load_quantile_if_present(outdir, filename, device): + filepath = os.path.join(outdir, filename) + if os.path.isfile(filepath): + data = numpy.load(filepath) + result = RunningQuantile(state=data) + result.to_(device) + return result + return None + +if __name__ == '__main__': + test_main() + +def mask_to_numpy(mask_record): + # Detect a png image mask. + bitstring = mask_record['bitstring'] + bitnumpy = None + default_shape = (256, 256) + if 'image/png;base64,' in bitstring: + bitnumpy = base642img(bitstring) + default_shape = bitnumpy.shape[:2] + # Set up results + shape = mask_record.get('shape', None) + if not shape: # None or empty [] + shape = default_shape + result = numpy.zeros(shape=shape, dtype=numpy.float32) + bitbounds = mask_record.get('bitbounds', None) + if not bitbounds: # None or empty [] + bitbounds = ([0] * len(result.shape)) + list(result.shape) + start = bitbounds[:len(result.shape)] + end = bitbounds[len(result.shape):] + if bitnumpy is not None: + if bitnumpy.shape[2] == 4: + # Mask is any nontransparent bits in the alpha channel if present + result[start[0]:end[0], start[1]:end[1]] = (bitnumpy[:,:,3] > 0) + else: + # Or any nonwhite pixels in the red channel if no alpha. + result[start[0]:end[0], start[1]:end[1]] = (bitnumpy[:,:,0] < 255) + return result + else: + # Or bitstring can be just ones and zeros. 
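+        # Example: with shape [2, 2], no bitbounds, and bitstring '0110', the loop below fills the mask row-major, giving [[0, 1], [1, 0]].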
+ indexes = start.copy() + bitindex = 0 + while True: + result[tuple(indexes)] = (bitstring[bitindex] != '0') + for ii in range(len(indexes) - 1, -1, -1): + if indexes[ii] < end[ii] - 1: + break + indexes[ii] = start[ii] + else: + assert (bitindex + 1) == len(bitstring) + return result + indexes[ii] += 1 + bitindex += 1 + +def decode_intervention_array(interventions, layer_shapes): + result = {} + for channels in [decode_intervention(intervention, layer_shapes) + for intervention in (interventions or [])]: + for layer, channel in channels.items(): + if layer not in result: + result[layer] = channel + continue + accum = result[layer] + newalpha = 1 - (1 - channel[:1]) * (1 - accum[:1]) + newvalue = (accum[1:] * accum[:1] * (1 - channel[:1]) + + channel[1:] * channel[:1]) / (newalpha + 1e-40) + accum[:1] = newalpha + accum[1:] = newvalue + return result + +def decode_intervention(intervention, layer_shapes): + # Every plane of an intervention is a solid choice of activation + # over a set of channels, with a mask applied to alpha-blended channels + # (when the mask resolution is different from the feature map, it can + # be either a max-pooled or average-pooled to the proper resolution). + # This can be reduced to a single alpha-blended featuremap. + if intervention is None: + return None + mask = intervention.get('mask', None) + if mask: + mask = torch.from_numpy(mask_to_numpy(mask)) + maskpooling = intervention.get('maskpooling', 'max') + channels = {} # layer -> ([alpha, val], c) + for arec in intervention.get('ablations', []): + unit = arec['unit'] + layer = arec['layer'] + alpha = arec.get('alpha', 1.0) + if alpha is None: + alpha = 1.0 + value = arec.get('value', 0.0) + if value is None: + value = 0.0 + if alpha != 0.0 or value != 0.0: + if layer not in channels: + channels[layer] = torch.zeros(2, *layer_shapes[layer][1:]) + channels[layer][0, unit] = alpha + channels[layer][1, unit] = value + if mask is not None: + for layer in channels: + layer_shape = layer_shapes[layer][2:] + if maskpooling == 'mean': + layer_mask = torch.nn.functional.adaptive_avg_pool2d( + mask[None,None,...], layer_shape)[0] + else: + layer_mask = torch.nn.functional.adaptive_max_pool2d( + mask[None,None,...], layer_shape)[0] + channels[layer][0] *= layer_mask + return channels + +def img2base64(imgarray, for_html=True, image_format='jpeg'): + ''' + Converts a numpy array to a jpeg base64 url + ''' + input_image_buff = BytesIO() + Image.fromarray(imgarray).save(input_image_buff, image_format, + quality=99, optimize=True, progressive=True) + res = base64.b64encode(input_image_buff.getvalue()).decode('ascii') + if for_html: + return 'data:image/' + image_format + ';base64,' + res + else: + return res + +def base642img(stringdata): + stringdata = re.sub('^(?:data:)?image/\w+;base64,', '', stringdata) + im = Image.open(BytesIO(base64.b64decode(stringdata))) + return numpy.array(im) + +def twitter_card(image_path, tweet_path, public_host): + return '''\ + + + + + + + + + + + + +
+

Painting with GANs from MIT-IBM Watson AI Lab

+

This demo lets you modify a selection of meaningful GAN units for a generated image by simply painting.

+ +

Redirecting to +GANPaint +

+
+ +'''.format( + image_path=image_path, + tweet_path=tweet_path, + public_host=public_host) diff --git a/netdissect/statedict.py b/netdissect/statedict.py new file mode 100644 index 0000000000000000000000000000000000000000..858a903b57724d9e3a17b8150beea30bdc206b97 --- /dev/null +++ b/netdissect/statedict.py @@ -0,0 +1,100 @@ +''' +Utilities for dealing with simple state dicts as npz files instead of pth files. +''' + +import torch +from collections.abc import MutableMapping, Mapping + +def load_from_numpy_dict(model, numpy_dict, prefix='', examples=None): + ''' + Loads a model from numpy_dict using load_state_dict. + Converts numpy types to torch types using the current state_dict + of the model to determine types and devices for the tensors. + Supports loading a subdict by prepending the given prefix to all keys. + ''' + if prefix: + if not prefix.endswith('.'): + prefix = prefix + '.' + numpy_dict = PrefixSubDict(numpy_dict, prefix) + if examples is None: + exampels = model.state_dict() + torch_state_dict = TorchTypeMatchingDict(numpy_dict, examples) + model.load_state_dict(torch_state_dict) + +def save_to_numpy_dict(model, numpy_dict, prefix=''): + ''' + Saves a model by copying tensors to numpy_dict. + Converts torch types to numpy types using `t.detach().cpu().numpy()`. + Supports saving a subdict by prepending the given prefix to all keys. + ''' + if prefix: + if not prefix.endswith('.'): + prefix = prefix + '.' + for k, v in model.numpy_dict().items(): + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + numpy_dict[prefix + k] = v + +class TorchTypeMatchingDict(Mapping): + ''' + Provides a view of a dict of numpy values as torch tensors, where the + types are converted to match the types and devices in the given + dict of examples. + ''' + def __init__(self, data, examples): + self.data = data + self.examples = examples + self.cached_data = {} + def __getitem__(self, key): + if key in self.cached_data: + return self.cached_data[key] + val = self.data[key] + if key not in self.examples: + return val + example = self.examples.get(key, None) + example_type = type(example) + if example is not None and type(val) != example_type: + if isinstance(example, torch.Tensor): + val = torch.from_numpy(val) + else: + val = example_type(val) + if isinstance(example, torch.Tensor): + val = val.to(dtype=example.dtype, device=example.device) + self.cached_data[key] = val + return val + def __iter__(self): + return self.data.keys() + def __len__(self): + return len(self.data) + +class PrefixSubDict(MutableMapping): + ''' + Provides a view of the subset of a dict where string keys begin with + the given prefix. The prefix is stripped from all keys of the view. 
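+
+ For example, PrefixSubDict(d, 'encoder.')['weight'] reads and writes
+ d['encoder.weight'], and iterating the view yields only the stripped
+ keys of entries whose original key starts with 'encoder.'.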
+ ''' + def __init__(self, data, prefix=''): + self.data = data + self.prefix = prefix + self._cached_keys = None + def __getitem__(self, key): + return self.data[self.prefix + key] + def __setitem__(self, key, value): + pkey = self.prefix + key + if self._cached_keys is not None and pkey not in self.data: + self._cached_keys = None + self.data[pkey] = value + def __delitem__(self, key): + pkey = self.prefix + key + if self._cached_keys is not None and pkey in self.data: + self._cached_keys = None + del self.data[pkey] + def __cached_keys(self): + if self._cached_keys is None: + plen = len(self.prefix) + self._cached_keys = list(k[plen:] for k in self.data + if k.startswith(self.prefix)) + return self._cached_keys + def __iter__(self): + return iter(self.__cached_keys()) + def __len__(self): + return len(self.__cached_keys()) diff --git a/netdissect/tool/allunitsample.py b/netdissect/tool/allunitsample.py new file mode 100644 index 0000000000000000000000000000000000000000..9f86e196ce63ebfcad1fcee8bd2b7358463ff3d1 --- /dev/null +++ b/netdissect/tool/allunitsample.py @@ -0,0 +1,199 @@ +''' +A simple tool to generate sample of output of a GAN, +subject to filtering, sorting, or intervention. +''' + +import torch, numpy, os, argparse, sys, shutil, errno, numbers +from PIL import Image +from torch.utils.data import TensorDataset +from netdissect.zdataset import standard_z_sample +from netdissect.progress import default_progress, verbose_progress +from netdissect.autoeval import autoimport_eval +from netdissect.workerpool import WorkerBase, WorkerPool +from netdissect.nethook import retain_layers +from netdissect.runningstats import RunningTopK + +def main(): + parser = argparse.ArgumentParser(description='GAN sample making utility') + parser.add_argument('--model', type=str, default=None, + help='constructor for the model to test') + parser.add_argument('--pthfile', type=str, default=None, + help='filename of .pth file for the model') + parser.add_argument('--outdir', type=str, default='images', + help='directory for image output') + parser.add_argument('--size', type=int, default=100, + help='number of images to output') + parser.add_argument('--test_size', type=int, default=None, + help='number of images to test') + parser.add_argument('--layer', type=str, default=None, + help='layer to inspect') + parser.add_argument('--seed', type=int, default=1, + help='seed') + parser.add_argument('--quiet', action='store_true', default=False, + help='silences console output') + if len(sys.argv) == 1: + parser.print_usage(sys.stderr) + sys.exit(1) + args = parser.parse_args() + verbose_progress(not args.quiet) + + # Instantiate the model + model = autoimport_eval(args.model) + if args.pthfile is not None: + data = torch.load(args.pthfile) + if 'state_dict' in data: + meta = {} + for key in data: + if isinstance(data[key], numbers.Number): + meta[key] = data[key] + data = data['state_dict'] + model.load_state_dict(data) + # Unwrap any DataParallel-wrapped model + if isinstance(model, torch.nn.DataParallel): + model = next(model.children()) + # Examine first conv in model to determine input feature size. + first_layer = [c for c in model.modules() + if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, + torch.nn.Linear))][0] + # 4d input if convolutional, 2d input if first layer is linear. 
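+ # (For example, a ProgGAN-style convolutional generator takes z of shape
+ # N x C x 1 x 1, while a generator whose first layer is Linear takes N x C.)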
+ if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + z_channels = first_layer.in_channels + spatialdims = (1, 1) + else: + z_channels = first_layer.in_features + spatialdims = () + # Instrument the model + retain_layers(model, [args.layer]) + model.cuda() + + if args.test_size is None: + args.test_size = args.size * 20 + z_universe = standard_z_sample(args.test_size, z_channels, + seed=args.seed) + z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims) + indexes = get_all_highest_znums( + model, z_universe, args.size, seed=args.seed) + save_chosen_unit_images(args.outdir, model, z_universe, indexes, + lightbox=True) + + +def get_all_highest_znums(model, z_universe, size, + batch_size=10, seed=1): + # The model should have been instrumented already + retained_items = list(model.retained.items()) + assert len(retained_items) == 1 + layer = retained_items[0][0] + # By default, a 10% sample + progress = default_progress() + num_units = None + with torch.no_grad(): + # Pass 1: collect max activation stats + z_loader = torch.utils.data.DataLoader(TensorDataset(z_universe), + batch_size=batch_size, num_workers=2, + pin_memory=True) + rtk = RunningTopK(k=size) + for [z] in progress(z_loader, desc='Finding max activations'): + z = z.cuda() + model(z) + feature = model.retained[layer] + num_units = feature.shape[1] + max_feature = feature.view( + feature.shape[0], num_units, -1).max(2)[0] + rtk.add(max_feature) + td, ti = rtk.result() + highest = ti.sort(1)[0] + return highest + +def save_chosen_unit_images(dirname, model, z_universe, indices, + shared_dir="shared_images", + unitdir_template="unit_{}", + name_template="image_{}.jpg", + lightbox=False, batch_size=50, seed=1): + all_indices = torch.unique(indices.view(-1), sorted=True) + z_sample = z_universe[all_indices] + progress = default_progress() + sdir = os.path.join(dirname, shared_dir) + created_hashdirs = set() + for index in range(len(z_universe)): + hd = hashdir(index) + if hd not in created_hashdirs: + created_hashdirs.add(hd) + os.makedirs(os.path.join(sdir, hd), exist_ok=True) + with torch.no_grad(): + # Pass 2: now generate images + z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample), + batch_size=batch_size, num_workers=2, + pin_memory=True) + saver = WorkerPool(SaveImageWorker) + for batch_num, [z] in enumerate(progress(z_loader, + desc='Saving images')): + z = z.cuda() + start_index = batch_num * batch_size + im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute( + 0, 2, 3, 1).cpu() + for i in range(len(im)): + index = all_indices[i + start_index].item() + filename = os.path.join(sdir, hashdir(index), + name_template.format(index)) + saver.add(im[i].numpy(), filename) + saver.join() + linker = WorkerPool(MakeLinkWorker) + for u in progress(range(len(indices)), desc='Making links'): + udir = os.path.join(dirname, unitdir_template.format(u)) + os.makedirs(udir, exist_ok=True) + for r in range(indices.shape[1]): + index = indices[u,r].item() + fn = name_template.format(index) + # sourcename = os.path.join('..', shared_dir, fn) + sourcename = os.path.join(sdir, hashdir(index), fn) + targname = os.path.join(udir, fn) + linker.add(sourcename, targname) + if lightbox: + copy_lightbox_to(udir) + linker.join() + +def copy_lightbox_to(dirname): + srcdir = os.path.realpath( + os.path.join(os.getcwd(), os.path.dirname(__file__))) + shutil.copy(os.path.join(srcdir, 'lightbox.html'), + os.path.join(dirname, '+lightbox.html')) + +def hashdir(index): + # To keep the number of files the shared 
directory lower, split it + # into 100 subdirectories named as follows. + return '%02d' % (index % 100) + +class SaveImageWorker(WorkerBase): + # Saving images can be sped up by sending jpeg encoding and + # file-writing work to a pool. + def work(self, data, filename): + Image.fromarray(data).save(filename, optimize=True, quality=100) + +class MakeLinkWorker(WorkerBase): + # Creating symbolic links is a bit slow and can be done faster + # in parallel rather than waiting for each to be created. + def work(self, sourcename, targname): + try: + os.link(sourcename, targname) + except OSError as e: + if e.errno == errno.EEXIST: + os.remove(targname) + os.link(sourcename, targname) + else: + raise + +class MakeSyminkWorker(WorkerBase): + # Creating symbolic links is a bit slow and can be done faster + # in parallel rather than waiting for each to be created. + def work(self, sourcename, targname): + try: + os.symlink(sourcename, targname) + except OSError as e: + if e.errno == errno.EEXIST: + os.remove(targname) + os.symlink(sourcename, targname) + else: + raise + +if __name__ == '__main__': + main() diff --git a/netdissect/tool/ganseg.py b/netdissect/tool/ganseg.py new file mode 100644 index 0000000000000000000000000000000000000000..e6225736d336cf75aedb8a7d7aec1229b497f6a9 --- /dev/null +++ b/netdissect/tool/ganseg.py @@ -0,0 +1,89 @@ +''' +A simple tool to generate sample of output of a GAN, +and apply semantic segmentation on the output. +''' + +import torch, numpy, os, argparse, sys, shutil +from PIL import Image +from torch.utils.data import TensorDataset +from netdissect.zdataset import standard_z_sample, z_dataset_for_model +from netdissect.progress import default_progress, verbose_progress +from netdissect.autoeval import autoimport_eval +from netdissect.workerpool import WorkerBase, WorkerPool +from netdissect.nethook import edit_layers, retain_layers +from netdissect.segviz import segment_visualization +from netdissect.segmenter import UnifiedParsingSegmenter +from scipy.io import savemat + +def main(): + parser = argparse.ArgumentParser(description='GAN output segmentation util') + parser.add_argument('--model', type=str, default= + 'netdissect.proggan.from_pth_file("' + + 'models/karras/churchoutdoor_lsun.pth")', + help='constructor for the model to test') + parser.add_argument('--outdir', type=str, default='images', + help='directory for image output') + parser.add_argument('--size', type=int, default=100, + help='number of images to output') + parser.add_argument('--seed', type=int, default=1, + help='seed') + parser.add_argument('--quiet', action='store_true', default=False, + help='silences console output') + #if len(sys.argv) == 1: + # parser.print_usage(sys.stderr) + # sys.exit(1) + args = parser.parse_args() + verbose_progress(not args.quiet) + + # Instantiate the model + model = autoimport_eval(args.model) + + # Make the standard z + z_dataset = z_dataset_for_model(model, size=args.size) + + # Make the segmenter + segmenter = UnifiedParsingSegmenter() + + # Write out text labels + labels, cats = segmenter.get_label_and_category_names() + with open(os.path.join(args.outdir, 'labels.txt'), 'w') as f: + for i, (label, cat) in enumerate(labels): + f.write('%s %s\n' % (label, cat)) + + # Move models to cuda + model.cuda() + + batch_size = 10 + progress = default_progress() + dirname = args.outdir + + with torch.no_grad(): + # Pass 2: now generate images + z_loader = torch.utils.data.DataLoader(z_dataset, + batch_size=batch_size, num_workers=2, + pin_memory=True) + for batch_num, 
[z] in enumerate(progress(z_loader, + desc='Saving images')): + z = z.cuda() + start_index = batch_num * batch_size + tensor_im = model(z) + byte_im = ((tensor_im + 1) / 2 * 255).clamp(0, 255).byte().permute( + 0, 2, 3, 1).cpu() + seg = segmenter.segment_batch(tensor_im) + for i in range(len(tensor_im)): + index = i + start_index + filename = os.path.join(dirname, '%d_img.jpg' % index) + Image.fromarray(byte_im[i].numpy()).save( + filename, optimize=True, quality=100) + filename = os.path.join(dirname, '%d_seg.mat' % index) + savemat(filename, dict(seg=seg[i].cpu().numpy())) + filename = os.path.join(dirname, '%d_seg.png' % index) + Image.fromarray(segment_visualization(seg[i].cpu().numpy(), + tensor_im.shape[2:])).save(filename) + srcdir = os.path.realpath( + os.path.join(os.getcwd(), os.path.dirname(__file__))) + shutil.copy(os.path.join(srcdir, 'lightbox.html'), + os.path.join(dirname, '+lightbox.html')) + +if __name__ == '__main__': + main() diff --git a/netdissect/tool/lightbox.html b/netdissect/tool/lightbox.html new file mode 100644 index 0000000000000000000000000000000000000000..fb0ebdf64766a43c9353428853be77deb5c52665 --- /dev/null +++ b/netdissect/tool/lightbox.html @@ -0,0 +1,59 @@ + + + + + + + + + + + + +
+

Images in {{ directory }}

+
+
{{ r }}
+ +
+
+ + + diff --git a/netdissect/tool/makesample.py b/netdissect/tool/makesample.py new file mode 100644 index 0000000000000000000000000000000000000000..36276267677360d8238a8dbf71e9753dcc327681 --- /dev/null +++ b/netdissect/tool/makesample.py @@ -0,0 +1,169 @@ +''' +A simple tool to generate sample of output of a GAN, +subject to filtering, sorting, or intervention. +''' + +import torch, numpy, os, argparse, numbers, sys, shutil +from PIL import Image +from torch.utils.data import TensorDataset +from netdissect.zdataset import standard_z_sample +from netdissect.progress import default_progress, verbose_progress +from netdissect.autoeval import autoimport_eval +from netdissect.workerpool import WorkerBase, WorkerPool +from netdissect.nethook import edit_layers, retain_layers + +def main(): + parser = argparse.ArgumentParser(description='GAN sample making utility') + parser.add_argument('--model', type=str, default=None, + help='constructor for the model to test') + parser.add_argument('--pthfile', type=str, default=None, + help='filename of .pth file for the model') + parser.add_argument('--outdir', type=str, default='images', + help='directory for image output') + parser.add_argument('--size', type=int, default=100, + help='number of images to output') + parser.add_argument('--test_size', type=int, default=None, + help='number of images to test') + parser.add_argument('--layer', type=str, default=None, + help='layer to inspect') + parser.add_argument('--seed', type=int, default=1, + help='seed') + parser.add_argument('--maximize_units', type=int, nargs='+', default=None, + help='units to maximize') + parser.add_argument('--ablate_units', type=int, nargs='+', default=None, + help='units to ablate') + parser.add_argument('--quiet', action='store_true', default=False, + help='silences console output') + if len(sys.argv) == 1: + parser.print_usage(sys.stderr) + sys.exit(1) + args = parser.parse_args() + verbose_progress(not args.quiet) + + # Instantiate the model + model = autoimport_eval(args.model) + if args.pthfile is not None: + data = torch.load(args.pthfile) + if 'state_dict' in data: + meta = {} + for key in data: + if isinstance(data[key], numbers.Number): + meta[key] = data[key] + data = data['state_dict'] + model.load_state_dict(data) + # Unwrap any DataParallel-wrapped model + if isinstance(model, torch.nn.DataParallel): + model = next(model.children()) + # Examine first conv in model to determine input feature size. + first_layer = [c for c in model.modules() + if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, + torch.nn.Linear))][0] + # 4d input if convolutional, 2d input if first layer is linear. + if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + z_channels = first_layer.in_channels + spatialdims = (1, 1) + else: + z_channels = first_layer.in_features + spatialdims = () + # Instrument the model if needed + if args.maximize_units is not None: + retain_layers(model, [args.layer]) + model.cuda() + + # Get the sample of z vectors + if args.maximize_units is None: + indexes = torch.arange(args.size) + z_sample = standard_z_sample(args.size, z_channels, seed=args.seed) + z_sample = z_sample.view(tuple(z_sample.shape) + spatialdims) + else: + # By default, if maximizing units, get a 'top 5%' sample. 
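+ # (test_size defaults to 20x the requested sample size below, so keeping
+ # the `size` highest-scoring z vectors amounts to roughly the top 5%.)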
+ if args.test_size is None: + args.test_size = args.size * 20 + z_universe = standard_z_sample(args.test_size, z_channels, + seed=args.seed) + z_universe = z_universe.view(tuple(z_universe.shape) + spatialdims) + indexes = get_highest_znums(model, z_universe, args.maximize_units, + args.size, seed=args.seed) + z_sample = z_universe[indexes] + + if args.ablate_units: + edit_layers(model, [args.layer]) + dims = max(2, max(args.ablate_units) + 1) # >=2 to avoid broadcast + model.ablation[args.layer] = torch.zeros(dims) + model.ablation[args.layer][args.ablate_units] = 1 + + save_znum_images(args.outdir, model, z_sample, indexes, + args.layer, args.ablate_units) + copy_lightbox_to(args.outdir) + + +def get_highest_znums(model, z_universe, max_units, size, + batch_size=100, seed=1): + # The model should have been instrumented already + retained_items = list(model.retained.items()) + assert len(retained_items) == 1 + layer = retained_items[0][0] + # By default, a 10% sample + progress = default_progress() + num_units = None + with torch.no_grad(): + # Pass 1: collect max activation stats + z_loader = torch.utils.data.DataLoader(TensorDataset(z_universe), + batch_size=batch_size, num_workers=2, + pin_memory=True) + scores = [] + for [z] in progress(z_loader, desc='Finding max activations'): + z = z.cuda() + model(z) + feature = model.retained[layer] + num_units = feature.shape[1] + max_feature = feature[:, max_units, ...].view( + feature.shape[0], len(max_units), -1).max(2)[0] + total_feature = max_feature.sum(1) + scores.append(total_feature.cpu()) + scores = torch.cat(scores, 0) + highest = (-scores).sort(0)[1][:size].sort(0)[0] + return highest + + +def save_znum_images(dirname, model, z_sample, indexes, layer, ablated_units, + name_template="image_{}.png", lightbox=False, batch_size=100, seed=1): + progress = default_progress() + os.makedirs(dirname, exist_ok=True) + with torch.no_grad(): + # Pass 2: now generate images + z_loader = torch.utils.data.DataLoader(TensorDataset(z_sample), + batch_size=batch_size, num_workers=2, + pin_memory=True) + saver = WorkerPool(SaveImageWorker) + if ablated_units is not None: + dims = max(2, max(ablated_units) + 1) # >=2 to avoid broadcast + mask = torch.zeros(dims) + mask[ablated_units] = 1 + model.ablation[layer] = mask[None,:,None,None].cuda() + for batch_num, [z] in enumerate(progress(z_loader, + desc='Saving images')): + z = z.cuda() + start_index = batch_num * batch_size + im = ((model(z) + 1) / 2 * 255).clamp(0, 255).byte().permute( + 0, 2, 3, 1).cpu() + for i in range(len(im)): + index = i + start_index + if indexes is not None: + index = indexes[index].item() + filename = os.path.join(dirname, name_template.format(index)) + saver.add(im[i].numpy(), filename) + saver.join() + +def copy_lightbox_to(dirname): + srcdir = os.path.realpath( + os.path.join(os.getcwd(), os.path.dirname(__file__))) + shutil.copy(os.path.join(srcdir, 'lightbox.html'), + os.path.join(dirname, '+lightbox.html')) + +class SaveImageWorker(WorkerBase): + def work(self, data, filename): + Image.fromarray(data).save(filename, optimize=True, quality=100) + +if __name__ == '__main__': + main() diff --git a/netdissect/upsegmodel/__init__.py b/netdissect/upsegmodel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76b40a0a36bc2976f185dbdc344c5a7c09b65920 --- /dev/null +++ b/netdissect/upsegmodel/__init__.py @@ -0,0 +1 @@ +from .models import ModelBuilder, SegmentationModule diff --git a/netdissect/upsegmodel/models.py b/netdissect/upsegmodel/models.py 
new file mode 100644 index 0000000000000000000000000000000000000000..de0a9add41016631957c52c4a441e4eccf96f903 --- /dev/null +++ b/netdissect/upsegmodel/models.py @@ -0,0 +1,441 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from . import resnet, resnext +try: + from lib.nn import SynchronizedBatchNorm2d +except ImportError: + from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d + + +class SegmentationModuleBase(nn.Module): + def __init__(self): + super(SegmentationModuleBase, self).__init__() + + @staticmethod + def pixel_acc(pred, label, ignore_index=-1): + _, preds = torch.max(pred, dim=1) + valid = (label != ignore_index).long() + acc_sum = torch.sum(valid * (preds == label).long()) + pixel_sum = torch.sum(valid) + acc = acc_sum.float() / (pixel_sum.float() + 1e-10) + return acc + + @staticmethod + def part_pixel_acc(pred_part, gt_seg_part, gt_seg_object, object_label, valid): + mask_object = (gt_seg_object == object_label) + _, pred = torch.max(pred_part, dim=1) + acc_sum = mask_object * (pred == gt_seg_part) + acc_sum = torch.sum(acc_sum.view(acc_sum.size(0), -1), dim=1) + acc_sum = torch.sum(acc_sum * valid) + pixel_sum = torch.sum(mask_object.view(mask_object.size(0), -1), dim=1) + pixel_sum = torch.sum(pixel_sum * valid) + return acc_sum, pixel_sum + + @staticmethod + def part_loss(pred_part, gt_seg_part, gt_seg_object, object_label, valid): + mask_object = (gt_seg_object == object_label) + loss = F.nll_loss(pred_part, gt_seg_part * mask_object.long(), reduction='none') + loss = loss * mask_object.float() + loss = torch.sum(loss.view(loss.size(0), -1), dim=1) + nr_pixel = torch.sum(mask_object.view(mask_object.shape[0], -1), dim=1) + sum_pixel = (nr_pixel * valid).sum() + loss = (loss * valid.float()).sum() / torch.clamp(sum_pixel, 1).float() + return loss + + +class SegmentationModule(SegmentationModuleBase): + def __init__(self, net_enc, net_dec, labeldata, loss_scale=None): + super(SegmentationModule, self).__init__() + self.encoder = net_enc + self.decoder = net_dec + self.crit_dict = nn.ModuleDict() + if loss_scale is None: + self.loss_scale = {"object": 1, "part": 0.5, "scene": 0.25, "material": 1} + else: + self.loss_scale = loss_scale + + # criterion + self.crit_dict["object"] = nn.NLLLoss(ignore_index=0) # ignore background 0 + self.crit_dict["material"] = nn.NLLLoss(ignore_index=0) # ignore background 0 + self.crit_dict["scene"] = nn.NLLLoss(ignore_index=-1) # ignore unlabelled -1 + + # Label data - read from json + self.labeldata = labeldata + object_to_num = {k: v for v, k in enumerate(labeldata['object'])} + part_to_num = {k: v for v, k in enumerate(labeldata['part'])} + self.object_part = {object_to_num[k]: + [part_to_num[p] for p in v] + for k, v in labeldata['object_part'].items()} + self.object_with_part = sorted(self.object_part.keys()) + self.decoder.object_part = self.object_part + self.decoder.object_with_part = self.object_with_part + + def forward(self, feed_dict, *, seg_size=None): + if seg_size is None: # training + + if feed_dict['source_idx'] == 0: + output_switch = {"object": True, "part": True, "scene": True, "material": False} + elif feed_dict['source_idx'] == 1: + output_switch = {"object": False, "part": False, "scene": False, "material": True} + else: + raise ValueError + + pred = self.decoder( + self.encoder(feed_dict['img'], return_feature_maps=True), + output_switch=output_switch + ) + + # loss + loss_dict = {} + if pred['object'] is not None: # object + loss_dict['object'] = 
self.crit_dict['object'](pred['object'], feed_dict['seg_object']) + if pred['part'] is not None: # part + part_loss = 0 + for idx_part, object_label in enumerate(self.object_with_part): + part_loss += self.part_loss( + pred['part'][idx_part], feed_dict['seg_part'], + feed_dict['seg_object'], object_label, feed_dict['valid_part'][:, idx_part]) + loss_dict['part'] = part_loss + if pred['scene'] is not None: # scene + loss_dict['scene'] = self.crit_dict['scene'](pred['scene'], feed_dict['scene_label']) + if pred['material'] is not None: # material + loss_dict['material'] = self.crit_dict['material'](pred['material'], feed_dict['seg_material']) + loss_dict['total'] = sum([loss_dict[k] * self.loss_scale[k] for k in loss_dict.keys()]) + + # metric + metric_dict= {} + if pred['object'] is not None: + metric_dict['object'] = self.pixel_acc( + pred['object'], feed_dict['seg_object'], ignore_index=0) + if pred['material'] is not None: + metric_dict['material'] = self.pixel_acc( + pred['material'], feed_dict['seg_material'], ignore_index=0) + if pred['part'] is not None: + acc_sum, pixel_sum = 0, 0 + for idx_part, object_label in enumerate(self.object_with_part): + acc, pixel = self.part_pixel_acc( + pred['part'][idx_part], feed_dict['seg_part'], feed_dict['seg_object'], + object_label, feed_dict['valid_part'][:, idx_part]) + acc_sum += acc + pixel_sum += pixel + metric_dict['part'] = acc_sum.float() / (pixel_sum.float() + 1e-10) + if pred['scene'] is not None: + metric_dict['scene'] = self.pixel_acc( + pred['scene'], feed_dict['scene_label'], ignore_index=-1) + + return {'metric': metric_dict, 'loss': loss_dict} + else: # inference + output_switch = {"object": True, "part": True, "scene": True, "material": True} + pred = self.decoder(self.encoder(feed_dict['img'], return_feature_maps=True), + output_switch=output_switch, seg_size=seg_size) + return pred + + +def conv3x3(in_planes, out_planes, stride=1, has_bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=has_bias) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + conv3x3(in_planes, out_planes, stride), + SynchronizedBatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class ModelBuilder: + def __init__(self): + pass + + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
+ m.bias.data.fill_(1e-4) + #elif classname.find('Linear') != -1: + # m.weight.data.normal_(0.0, 0.0001) + + def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''): + pretrained = True if len(weights) == 0 else False + if arch == 'resnet34': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet34_dilated8': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=8) + elif arch == 'resnet34_dilated16': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, + dilate_scale=16) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet101': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnext101': + orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnext) # we can still use class Resnet + else: + raise Exception('Architecture undefined!') + + # net_encoder.apply(self.weights_init) + if len(weights) > 0: + # print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_encoder + + def build_decoder(self, nr_classes, + arch='ppm_bilinear_deepsup', fc_dim=512, + weights='', use_softmax=False): + if arch == 'upernet_lite': + net_decoder = UPerNet( + nr_classes=nr_classes, + fc_dim=fc_dim, + use_softmax=use_softmax, + fpn_dim=256) + elif arch == 'upernet': + net_decoder = UPerNet( + nr_classes=nr_classes, + fc_dim=fc_dim, + use_softmax=use_softmax, + fpn_dim=512) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(self.weights_init) + if len(weights) > 0: + # print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_decoder + + +class Resnet(nn.Module): + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +# upernet +class UPerNet(nn.Module): + def __init__(self, nr_classes, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6), + fpn_inplanes=(256,512,1024,2048), fpn_dim=256): + # Lazy import so that compilation isn't needed if not being used. 
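+ # (Importing prroi_pool builds its CUDA extension with the torch JIT
+ # cpp_extension loader, so the import is deferred until a UPerNet is built.)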
+ from .prroi_pool import PrRoIPool2D + super(UPerNet, self).__init__() + self.use_softmax = use_softmax + + # PPM Module + self.ppm_pooling = [] + self.ppm_conv = [] + + for scale in pool_scales: + # we use the feature map size instead of input image size, so down_scale = 1.0 + self.ppm_pooling.append(PrRoIPool2D(scale, scale, 1.)) + self.ppm_conv.append(nn.Sequential( + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + SynchronizedBatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm_pooling = nn.ModuleList(self.ppm_pooling) + self.ppm_conv = nn.ModuleList(self.ppm_conv) + self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) + + # FPN Module + self.fpn_in = [] + for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer + self.fpn_in.append(nn.Sequential( + nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), + SynchronizedBatchNorm2d(fpn_dim), + nn.ReLU(inplace=True) + )) + self.fpn_in = nn.ModuleList(self.fpn_in) + + self.fpn_out = [] + for i in range(len(fpn_inplanes) - 1): # skip the top layer + self.fpn_out.append(nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + )) + self.fpn_out = nn.ModuleList(self.fpn_out) + + self.conv_fusion = conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1) + + # background included. if ignore in loss, output channel 0 will not be trained. + self.nr_scene_class, self.nr_object_class, self.nr_part_class, self.nr_material_class = \ + nr_classes['scene'], nr_classes['object'], nr_classes['part'], nr_classes['material'] + + # input: PPM out, input_dim: fpn_dim + self.scene_head = nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(fpn_dim, self.nr_scene_class, kernel_size=1, bias=True) + ) + + # input: Fusion out, input_dim: fpn_dim + self.object_head = nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, self.nr_object_class, kernel_size=1, bias=True) + ) + + # input: Fusion out, input_dim: fpn_dim + self.part_head = nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, self.nr_part_class, kernel_size=1, bias=True) + ) + + # input: FPN_2 (P2), input_dim: fpn_dim + self.material_head = nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, self.nr_material_class, kernel_size=1, bias=True) + ) + + def forward(self, conv_out, output_switch=None, seg_size=None): + + output_dict = {k: None for k in output_switch.keys()} + + conv5 = conv_out[-1] + input_size = conv5.size() + ppm_out = [conv5] + roi = [] # fake rois, just used for pooling + for i in range(input_size[0]): # batch size + roi.append(torch.Tensor([i, 0, 0, input_size[3], input_size[2]]).view(1, -1)) # b, x0, y0, x1, y1 + roi = torch.cat(roi, dim=0).type_as(conv5) + ppm_out = [conv5] + for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): + ppm_out.append(pool_conv(F.interpolate( + pool_scale(conv5, roi.detach()), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False))) + ppm_out = torch.cat(ppm_out, 1) + f = self.ppm_last_conv(ppm_out) + + if output_switch['scene']: # scene + output_dict['scene'] = self.scene_head(f) + + if output_switch['object'] or output_switch['part'] or output_switch['material']: + fpn_feature_list = [f] + for i in reversed(range(len(conv_out) - 1)): + conv_x = conv_out[i] + conv_x = self.fpn_in[i](conv_x) # lateral branch + + f = F.interpolate( + f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch + f = conv_x + f + + fpn_feature_list.append(self.fpn_out[i](f)) 
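+ # fpn_feature_list is accumulated coarse-to-fine here ([P5 ... P2]);
+ # the reverse() below puts the highest-resolution P2 map first.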
+ fpn_feature_list.reverse() # [P2 - P5] + + # material + if output_switch['material']: + output_dict['material'] = self.material_head(fpn_feature_list[0]) + + if output_switch['object'] or output_switch['part']: + output_size = fpn_feature_list[0].size()[2:] + fusion_list = [fpn_feature_list[0]] + for i in range(1, len(fpn_feature_list)): + fusion_list.append(F.interpolate( + fpn_feature_list[i], + output_size, + mode='bilinear', align_corners=False)) + fusion_out = torch.cat(fusion_list, 1) + x = self.conv_fusion(fusion_out) + + if output_switch['object']: # object + output_dict['object'] = self.object_head(x) + if output_switch['part']: + output_dict['part'] = self.part_head(x) + + if self.use_softmax: # is True during inference + # inference scene + x = output_dict['scene'] + x = x.squeeze(3).squeeze(2) + x = F.softmax(x, dim=1) + output_dict['scene'] = x + + # inference object, material + for k in ['object', 'material']: + x = output_dict[k] + x = F.interpolate(x, size=seg_size, mode='bilinear', align_corners=False) + x = F.softmax(x, dim=1) + output_dict[k] = x + + # inference part + x = output_dict['part'] + x = F.interpolate(x, size=seg_size, mode='bilinear', align_corners=False) + part_pred_list, head = [], 0 + for idx_part, object_label in enumerate(self.object_with_part): + n_part = len(self.object_part[object_label]) + _x = F.interpolate(x[:, head: head + n_part], size=seg_size, mode='bilinear', align_corners=False) + _x = F.softmax(_x, dim=1) + part_pred_list.append(_x) + head += n_part + output_dict['part'] = part_pred_list + + else: # Training + # object, scene, material + for k in ['object', 'scene', 'material']: + if output_dict[k] is None: + continue + x = output_dict[k] + x = F.log_softmax(x, dim=1) + if k == "scene": # for scene + x = x.squeeze(3).squeeze(2) + output_dict[k] = x + if output_dict['part'] is not None: + part_pred_list, head = [], 0 + for idx_part, object_label in enumerate(self.object_with_part): + n_part = len(self.object_part[object_label]) + x = output_dict['part'][:, head: head + n_part] + x = F.log_softmax(x, dim=1) + part_pred_list.append(x) + head += n_part + output_dict['part'] = part_pred_list + + return output_dict diff --git a/netdissect/upsegmodel/prroi_pool/.gitignore b/netdissect/upsegmodel/prroi_pool/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..18495eade007bad45f5ca772771d99f91e441e50 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/.gitignore @@ -0,0 +1,2 @@ +*.o +/_prroi_pooling diff --git a/netdissect/upsegmodel/prroi_pool/README.md b/netdissect/upsegmodel/prroi_pool/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bb98946d3b48a2069a58f179eb6da63e009c3849 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/README.md @@ -0,0 +1,66 @@ +# PreciseRoIPooling +This repo implements the **Precise RoI Pooling** (PrRoI Pooling), proposed in the paper **Acquisition of Localization Confidence for Accurate Object Detection** published at ECCV 2018 (Oral Presentation). + +**Acquisition of Localization Confidence for Accurate Object Detection** + +_Borui Jiang*, Ruixuan Luo*, Jiayuan Mao*, Tete Xiao, Yuning Jiang_ (* indicates equal contribution.) + +https://arxiv.org/abs/1807.11590 + +## Brief + +In short, Precise RoI Pooling is an integration-based (bilinear interpolation) average pooling method for RoI Pooling. It avoids any quantization and has a continuous gradient on bounding box coordinates. 
It is:
+
+- different from the original RoI Pooling proposed in [Fast R-CNN](https://arxiv.org/abs/1504.08083). PrRoI Pooling uses average pooling instead of max pooling for each bin and has a continuous gradient on bounding box coordinates. That is, one can take the derivatives of some loss function w.r.t. the coordinates of each RoI and optimize the RoI coordinates.
+- different from the RoI Align proposed in [Mask R-CNN](https://arxiv.org/abs/1703.06870). PrRoI Pooling uses a full integration-based average pooling instead of sampling a constant number of points. This makes the gradient w.r.t. the coordinates continuous.
+
+For a better illustration, we compare RoI Pooling, RoI Align and PrRoI Pooling in the following figure. More details including the gradient computation can be found in our paper.
+
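+As a rough sketch of the computation (this paraphrases the paper and is not part of the original README): for a bin with continuous bounds (x1, y1, x2, y2) over the bilinearly interpolated feature map f(x, y), PrRoI Pooling returns the exact average
+
+    PrPool(bin) = ( \int_{y1}^{y2} \int_{x1}^{x2} f(x, y) dx dy ) / ( (x2 - x1) * (y2 - y1) )
+
+which varies smoothly with the bin coordinates, so the gradient w.r.t. (x1, y1, x2, y2) is continuous.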
+ +## Implementation + +PrRoI Pooling was originally implemented by [Tete Xiao](http://tetexiao.com/) based on MegBrain, an (internal) deep learning framework built by Megvii Inc. It was later adapted into open-source deep learning frameworks. Currently, we only support PyTorch. Unfortunately, we don't have any specific plan for the adaptation into other frameworks such as TensorFlow, but any contributions (pull requests) will be more than welcome. + +## Usage (PyTorch 1.0) + +In the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 1.0+ and only supports CUDA (CPU mode is not implemented). +Since we use PyTorch JIT for cxx/cuda code compilation, to use the module in your code, simply do: + +``` +from prroi_pool import PrRoIPool2D + +avg_pool = PrRoIPool2D(window_height, window_width, spatial_scale) +roi_features = avg_pool(features, rois) + +# for those who want to use the "functional" + +from prroi_pool.functional import prroi_pool2d +roi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale) +``` + + +## Usage (PyTorch 0.4) + +**!!! Please first checkout to the branch pytorch0.4.** + +In the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 0.4 and only supports CUDA (CPU mode is not implemented). +To use the PrRoI Pooling module, first goto `pytorch/prroi_pool` and execute `./travis.sh` to compile the essential components (you may need `nvcc` for this step). To use the module in your code, simply do: + +``` +from prroi_pool import PrRoIPool2D + +avg_pool = PrRoIPool2D(window_height, window_width, spatial_scale) +roi_features = avg_pool(features, rois) + +# for those who want to use the "functional" + +from prroi_pool.functional import prroi_pool2d +roi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale) +``` + +Here, + +- RoI is an `m * 5` float tensor of format `(batch_index, x0, y0, x1, y1)`, following the convention in the original Caffe implementation of RoI Pooling, although in some frameworks the batch indices are provided by an integer tensor. +- `spatial_scale` is multiplied to the RoIs. For example, if your feature maps are down-sampled by a factor of 16 (w.r.t. the input image), you should use a spatial scale of `1/16`. +- The coordinates for RoI follows the [L, R) convension. That is, `(0, 0, 4, 4)` denotes a box of size `4x4`. diff --git a/netdissect/upsegmodel/prroi_pool/__init__.py b/netdissect/upsegmodel/prroi_pool/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c40b7a7e2bca8a0dbd28e13815f2f2ad6c4728b --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/__init__.py @@ -0,0 +1,13 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao, Tete Xiao +# Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com +# Date : 07/13/2018 +# +# This file is part of PreciseRoIPooling. +# Distributed under terms of the MIT license. +# Copyright (c) 2017 Megvii Technology Limited. + +from .prroi_pool import * + diff --git a/netdissect/upsegmodel/prroi_pool/build.py b/netdissect/upsegmodel/prroi_pool/build.py new file mode 100644 index 0000000000000000000000000000000000000000..b198790817a2d11d65d6211b011f9408d9d34270 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/build.py @@ -0,0 +1,50 @@ +#! 
/usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : build.py +# Author : Jiayuan Mao, Tete Xiao +# Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com +# Date : 07/13/2018 +# +# This file is part of PreciseRoIPooling. +# Distributed under terms of the MIT license. +# Copyright (c) 2017 Megvii Technology Limited. + +import os +import torch + +from torch.utils.ffi import create_extension + +headers = [] +sources = [] +defines = [] +extra_objects = [] +with_cuda = False + +if torch.cuda.is_available(): + with_cuda = True + + headers+= ['src/prroi_pooling_gpu.h'] + sources += ['src/prroi_pooling_gpu.c'] + defines += [('WITH_CUDA', None)] + + this_file = os.path.dirname(os.path.realpath(__file__)) + extra_objects_cuda = ['src/prroi_pooling_gpu_impl.cu.o'] + extra_objects_cuda = [os.path.join(this_file, fname) for fname in extra_objects_cuda] + extra_objects.extend(extra_objects_cuda) +else: + # TODO(Jiayuan Mao @ 07/13): remove this restriction after we support the cpu implementation. + raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.') + +ffi = create_extension( + '_prroi_pooling', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects +) + +if __name__ == '__main__': + ffi.build() + diff --git a/netdissect/upsegmodel/prroi_pool/functional.py b/netdissect/upsegmodel/prroi_pool/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..7dc7a8c282e846bd633c4fdc4190c4dca3da5a6f --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/functional.py @@ -0,0 +1,70 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : functional.py +# Author : Jiayuan Mao, Tete Xiao +# Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com +# Date : 07/13/2018 +# +# This file is part of PreciseRoIPooling. +# Distributed under terms of the MIT license. +# Copyright (c) 2017 Megvii Technology Limited. + +import torch +import torch.autograd as ag + +try: + from os.path import join as pjoin, dirname + from torch.utils.cpp_extension import load as load_extension + root_dir = pjoin(dirname(__file__), 'src') + _prroi_pooling = load_extension( + '_prroi_pooling', + [pjoin(root_dir, 'prroi_pooling_gpu.c'), pjoin(root_dir, 'prroi_pooling_gpu_impl.cu')], + verbose=False + ) +except ImportError: + raise ImportError('Can not compile Precise RoI Pooling library.') + +__all__ = ['prroi_pool2d'] + + +class PrRoIPool2DFunction(ag.Function): + @staticmethod + def forward(ctx, features, rois, pooled_height, pooled_width, spatial_scale): + assert 'FloatTensor' in features.type() and 'FloatTensor' in rois.type(), \ + 'Precise RoI Pooling only takes float input, got {} for features and {} for rois.'.format(features.type(), rois.type()) + + pooled_height = int(pooled_height) + pooled_width = int(pooled_width) + spatial_scale = float(spatial_scale) + + features = features.contiguous() + rois = rois.contiguous() + params = (pooled_height, pooled_width, spatial_scale) + + if features.is_cuda: + output = _prroi_pooling.prroi_pooling_forward_cuda(features, rois, *params) + ctx.params = params + # everything here is contiguous. 
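+ # backward() retrieves these three tensors and hands them to the CUDA
+ # gradient wrappers for the input and coordinate gradients.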
+ ctx.save_for_backward(features, rois, output) + else: + raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.') + + return output + + @staticmethod + def backward(ctx, grad_output): + features, rois, output = ctx.saved_tensors + grad_input = grad_coor = None + + if features.requires_grad: + grad_output = grad_output.contiguous() + grad_input = _prroi_pooling.prroi_pooling_backward_cuda(features, rois, output, grad_output, *ctx.params) + if rois.requires_grad: + grad_output = grad_output.contiguous() + grad_coor = _prroi_pooling.prroi_pooling_coor_backward_cuda(features, rois, output, grad_output, *ctx.params) + + return grad_input, grad_coor, None, None, None + + +prroi_pool2d = PrRoIPool2DFunction.apply + diff --git a/netdissect/upsegmodel/prroi_pool/prroi_pool.py b/netdissect/upsegmodel/prroi_pool/prroi_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..998b2b80531058fa91ac138e79ae39c5c0174601 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/prroi_pool.py @@ -0,0 +1,28 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : prroi_pool.py +# Author : Jiayuan Mao, Tete Xiao +# Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com +# Date : 07/13/2018 +# +# This file is part of PreciseRoIPooling. +# Distributed under terms of the MIT license. +# Copyright (c) 2017 Megvii Technology Limited. + +import torch.nn as nn + +from .functional import prroi_pool2d + +__all__ = ['PrRoIPool2D'] + + +class PrRoIPool2D(nn.Module): + def __init__(self, pooled_height, pooled_width, spatial_scale): + super().__init__() + + self.pooled_height = int(pooled_height) + self.pooled_width = int(pooled_width) + self.spatial_scale = float(spatial_scale) + + def forward(self, features, rois): + return prroi_pool2d(features, rois, self.pooled_height, self.pooled_width, self.spatial_scale) diff --git a/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c new file mode 100644 index 0000000000000000000000000000000000000000..1e652963cdb76fe628d0a33bc270d2c25a0f3770 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c @@ -0,0 +1,113 @@ +/* + * File : prroi_pooling_gpu.c + * Author : Jiayuan Mao, Tete Xiao + * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com + * Date : 07/13/2018 + * + * Distributed under terms of the MIT license. + * Copyright (c) 2017 Megvii Technology Limited. 
+ */ + +#include +#include + +#include +#include + +#include + +#include "prroi_pooling_gpu_impl.cuh" + + +at::Tensor prroi_pooling_forward_cuda(const at::Tensor &features, const at::Tensor &rois, int pooled_height, int pooled_width, float spatial_scale) { + int nr_rois = rois.size(0); + int nr_channels = features.size(1); + int height = features.size(2); + int width = features.size(3); + int top_count = nr_rois * nr_channels * pooled_height * pooled_width; + auto output = at::zeros({nr_rois, nr_channels, pooled_height, pooled_width}, features.options()); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return output; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + PrRoIPoolingForwardGpu( + stream, features.data(), rois.data(), output.data(), + nr_channels, height, width, pooled_height, pooled_width, spatial_scale, + top_count + ); + + THCudaCheck(cudaGetLastError()); + return output; +} + +at::Tensor prroi_pooling_backward_cuda( + const at::Tensor &features, const at::Tensor &rois, const at::Tensor &output, const at::Tensor &output_diff, + int pooled_height, int pooled_width, float spatial_scale) { + + auto features_diff = at::zeros_like(features); + + int nr_rois = rois.size(0); + int batch_size = features.size(0); + int nr_channels = features.size(1); + int height = features.size(2); + int width = features.size(3); + int top_count = nr_rois * nr_channels * pooled_height * pooled_width; + int bottom_count = batch_size * nr_channels * height * width; + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return features_diff; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + PrRoIPoolingBackwardGpu( + stream, + features.data(), rois.data(), output.data(), output_diff.data(), + features_diff.data(), + nr_channels, height, width, pooled_height, pooled_width, spatial_scale, + top_count, bottom_count + ); + + THCudaCheck(cudaGetLastError()); + return features_diff; +} + +at::Tensor prroi_pooling_coor_backward_cuda( + const at::Tensor &features, const at::Tensor &rois, const at::Tensor &output, const at::Tensor &output_diff, + int pooled_height, int pooled_width, float spatial_scale) { + + auto coor_diff = at::zeros_like(rois); + + int nr_rois = rois.size(0); + int nr_channels = features.size(1); + int height = features.size(2); + int width = features.size(3); + int top_count = nr_rois * nr_channels * pooled_height * pooled_width; + int bottom_count = nr_rois * 5; + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return coor_diff; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + PrRoIPoolingCoorBackwardGpu( + stream, + features.data(), rois.data(), output.data(), output_diff.data(), + coor_diff.data(), + nr_channels, height, width, pooled_height, pooled_width, spatial_scale, + top_count, bottom_count + ); + + THCudaCheck(cudaGetLastError()); + return coor_diff; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("prroi_pooling_forward_cuda", &prroi_pooling_forward_cuda, "PRRoIPooling_forward"); + m.def("prroi_pooling_backward_cuda", &prroi_pooling_backward_cuda, "PRRoIPooling_backward"); + m.def("prroi_pooling_coor_backward_cuda", &prroi_pooling_coor_backward_cuda, "PRRoIPooling_backward_coor"); +} diff --git a/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.h b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..bc9d35181dd97c355fb6a5b17bc9e82e24ef1566 --- /dev/null +++ 
b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.h @@ -0,0 +1,22 @@ +/* + * File : prroi_pooling_gpu.h + * Author : Jiayuan Mao, Tete Xiao + * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com + * Date : 07/13/2018 + * + * Distributed under terms of the MIT license. + * Copyright (c) 2017 Megvii Technology Limited. + */ + +int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale); + +int prroi_pooling_backward_cuda( + THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff, + int pooled_height, int pooled_width, float spatial_scale +); + +int prroi_pooling_coor_backward_cuda( + THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff, + int pooled_height, int pooled_width, float spatial_scal +); + diff --git a/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cu b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cu new file mode 100644 index 0000000000000000000000000000000000000000..452b02055495ad721ba41b2708bccecc9b1aa2f3 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cu @@ -0,0 +1,443 @@ +/* + * File : prroi_pooling_gpu_impl.cu + * Author : Tete Xiao, Jiayuan Mao + * Email : jasonhsiao97@gmail.com + * + * Distributed under terms of the MIT license. + * Copyright (c) 2017 Megvii Technology Limited. + */ + +#include "prroi_pooling_gpu_impl.cuh" + +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +#define CUDA_POST_KERNEL_CHECK \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); \ + exit(-1); \ + } \ + } while(0) + +#define CUDA_NUM_THREADS 512 + +namespace { + +static int CUDA_NUM_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +__device__ static float PrRoIPoolingGetData(F_DEVPTR_IN data, const int h, const int w, const int height, const int width) +{ + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + float retVal = overflow ? 0.0f : data[h * width + w]; + return retVal; +} + +__device__ static float PrRoIPoolingGetCoeff(float dh, float dw){ + dw = dw > 0 ? dw : -dw; + dh = dh > 0 ? 
dh : -dh; + return (1.0f - dh) * (1.0f - dw); +} + +__device__ static float PrRoIPoolingSingleCoorIntegral(float s, float t, float c1, float c2) { + return 0.5 * (t * t - s * s) * c2 + (t - 0.5 * t * t - s + 0.5 * s * s) * c1; +} + +__device__ static float PrRoIPoolingInterpolation(F_DEVPTR_IN data, const float h, const float w, const int height, const int width){ + float retVal = 0.0f; + int h1 = floorf(h); + int w1 = floorf(w); + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1)); + h1 = floorf(h)+1; + w1 = floorf(w); + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1)); + h1 = floorf(h); + w1 = floorf(w)+1; + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1)); + h1 = floorf(h)+1; + w1 = floorf(w)+1; + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1)); + return retVal; +} + +__device__ static float PrRoIPoolingMatCalculation(F_DEVPTR_IN this_data, const int s_h, const int s_w, const int e_h, const int e_w, + const float y0, const float x0, const float y1, const float x1, const int h0, const int w0) +{ + float alpha, beta, lim_alpha, lim_beta, tmp; + float sum_out = 0; + + alpha = x0 - float(s_w); + beta = y0 - float(s_h); + lim_alpha = x1 - float(s_w); + lim_beta = y1 - float(s_h); + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp; + + alpha = float(e_w) - x1; + lim_alpha = float(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp; + + alpha = x0 - float(s_w); + beta = float(e_h) - y1; + lim_alpha = x1 - float(s_w); + lim_beta = float(e_h) - y0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp; + + alpha = float(e_w) - x1; + lim_alpha = float(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp; + + return sum_out; +} + +__device__ static void PrRoIPoolingDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int h, const int w, const int height, const int width, const float coeff) +{ + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + if (!overflow) + atomicAdd(diff + h * width + w, top_diff * coeff); +} + +__device__ static void PrRoIPoolingMatDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int s_h, const int s_w, const int e_h, const int e_w, + const float y0, const float x0, const float y1, const float x1, const int h0, const int w0) +{ + float alpha, beta, lim_alpha, lim_beta, tmp; + + alpha = x0 - float(s_w); + beta = y0 - float(s_h); + lim_alpha = x1 - float(s_w); + lim_beta = y1 - float(s_h); + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, 
w0, tmp); + + alpha = float(e_w) - x1; + lim_alpha = float(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp); + + alpha = x0 - float(s_w); + beta = float(e_h) - y1; + lim_alpha = x1 - float(s_w); + lim_beta = float(e_h) - y0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp); + + alpha = float(e_w) - x1; + lim_alpha = float(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) + * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp); +} + +__global__ void PrRoIPoolingForward( + const int nthreads, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_OUT top_data, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const float spatial_scale) { + + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + bottom_rois += n * 5; + int roi_batch_ind = bottom_rois[0]; + + float roi_start_w = bottom_rois[1] * spatial_scale; + float roi_start_h = bottom_rois[2] * spatial_scale; + float roi_end_w = bottom_rois[3] * spatial_scale; + float roi_end_h = bottom_rois[4] * spatial_scale; + + float roi_width = max(roi_end_w - roi_start_w, ((float)0.0)); + float roi_height = max(roi_end_h - roi_start_h, ((float)0.0)); + float bin_size_h = roi_height / static_cast<float>(pooled_height); + float bin_size_w = roi_width / static_cast<float>(pooled_width); + + const float *this_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + float *this_out = top_data + index; + + float win_start_w = roi_start_w + bin_size_w * pw; + float win_start_h = roi_start_h + bin_size_h * ph; + float win_end_w = win_start_w + bin_size_w; + float win_end_h = win_start_h + bin_size_h; + + float win_size = max(float(0.0), bin_size_w * bin_size_h); + if (win_size == 0) { + *this_out = 0; + return; + } + + float sum_out = 0; + + int s_w, s_h, e_w, e_h; + + s_w = floorf(win_start_w); + e_w = ceilf(win_end_w); + s_h = floorf(win_start_h); + e_h = ceilf(win_end_h); + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) + for (int h_iter = s_h; h_iter < e_h; ++h_iter) + sum_out += PrRoIPoolingMatCalculation(this_data, h_iter, w_iter, h_iter + 1, w_iter + 1, + max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)), + min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)), + height, width); + *this_out = sum_out / win_size; + } +} + +__global__ void PrRoIPoolingBackward( + const int nthreads, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_IN top_diff, + F_DEVPTR_OUT bottom_diff, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const float spatial_scale) { + + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / 
pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + bottom_rois += n * 5; + + int roi_batch_ind = bottom_rois[0]; + float roi_start_w = bottom_rois[1] * spatial_scale; + float roi_start_h = bottom_rois[2] * spatial_scale; + float roi_end_w = bottom_rois[3] * spatial_scale; + float roi_end_h = bottom_rois[4] * spatial_scale; + + float roi_width = max(roi_end_w - roi_start_w, (float)0); + float roi_height = max(roi_end_h - roi_start_h, (float)0); + float bin_size_h = roi_height / static_cast<float>(pooled_height); + float bin_size_w = roi_width / static_cast<float>(pooled_width); + + const float *this_out_grad = top_diff + index; + float *this_data_grad = bottom_diff + (roi_batch_ind * channels + c) * height * width; + + float win_start_w = roi_start_w + bin_size_w * pw; + float win_start_h = roi_start_h + bin_size_h * ph; + float win_end_w = win_start_w + bin_size_w; + float win_end_h = win_start_h + bin_size_h; + + float win_size = max(float(0.0), bin_size_w * bin_size_h); + + float sum_out = win_size == float(0) ? float(0) : *this_out_grad / win_size; + + int s_w, s_h, e_w, e_h; + + s_w = floorf(win_start_w); + e_w = ceilf(win_end_w); + s_h = floorf(win_start_h); + e_h = ceilf(win_end_h); + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) + for (int h_iter = s_h; h_iter < e_h; ++h_iter) + PrRoIPoolingMatDistributeDiff(this_data_grad, sum_out, h_iter, w_iter, h_iter + 1, w_iter + 1, + max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)), + min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)), + height, width); + + } +} + +__global__ void PrRoIPoolingCoorBackward( + const int nthreads, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_IN top_data, + F_DEVPTR_IN top_diff, + F_DEVPTR_OUT bottom_diff, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const float spatial_scale) { + + CUDA_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + bottom_rois += n * 5; + + int roi_batch_ind = bottom_rois[0]; + float roi_start_w = bottom_rois[1] * spatial_scale; + float roi_start_h = bottom_rois[2] * spatial_scale; + float roi_end_w = bottom_rois[3] * spatial_scale; + float roi_end_h = bottom_rois[4] * spatial_scale; + + float roi_width = max(roi_end_w - roi_start_w, (float)0); + float roi_height = max(roi_end_h - roi_start_h, (float)0); + float bin_size_h = roi_height / static_cast<float>(pooled_height); + float bin_size_w = roi_width / static_cast<float>(pooled_width); + + const float *this_out_grad = top_diff + index; + const float *this_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + const float *this_top_data = top_data + index; + float *this_data_grad = bottom_diff + n * 5; + + float win_start_w = roi_start_w + bin_size_w * pw; + float win_start_h = roi_start_h + bin_size_h * ph; + float win_end_w = win_start_w + bin_size_w; + float win_end_h = win_start_h + bin_size_h; + + float win_size = max(float(0.0), bin_size_w * bin_size_h); + + float sum_out = win_size == float(0) ? 
float(0) : *this_out_grad / win_size; + + // WARNING: to be discussed + if (sum_out == 0) + return; + + int s_w, s_h, e_w, e_h; + + s_w = floorf(win_start_w); + e_w = ceilf(win_end_w); + s_h = floorf(win_start_h); + e_h = ceilf(win_end_h); + + float g_x1_y = 0, g_x2_y = 0, g_x_y1 = 0, g_x_y2 = 0; + for (int h_iter = s_h; h_iter < e_h; ++h_iter) { + g_x1_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter, + min(win_end_h, float(h_iter + 1)) - h_iter, + PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height, width), + PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w, height, width)); + + g_x2_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter, + min(win_end_h, float(h_iter + 1)) - h_iter, + PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height, width), + PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w, height, width)); + } + + for (int w_iter = s_w; w_iter < e_w; ++w_iter) { + g_x_y1 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter, + min(win_end_w, float(w_iter + 1)) - w_iter, + PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height, width), + PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1, height, width)); + + g_x_y2 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter, + min(win_end_w, float(w_iter + 1)) - w_iter, + PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height, width), + PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1, height, width)); + } + + float partial_x1 = -g_x1_y + (win_end_h - win_start_h) * (*this_top_data); + float partial_y1 = -g_x_y1 + (win_end_w - win_start_w) * (*this_top_data); + float partial_x2 = g_x2_y - (win_end_h - win_start_h) * (*this_top_data); + float partial_y2 = g_x_y2 - (win_end_w - win_start_w) * (*this_top_data); + + partial_x1 = partial_x1 / win_size * spatial_scale; + partial_x2 = partial_x2 / win_size * spatial_scale; + partial_y1 = partial_y1 / win_size * spatial_scale; + partial_y2 = partial_y2 / win_size * spatial_scale; + + // (b, x1, y1, x2, y2) + + this_data_grad[0] = 0; + atomicAdd(this_data_grad + 1, (partial_x1 * (1.0 - float(pw) / pooled_width) + partial_x2 * (1.0 - float(pw + 1) / pooled_width)) + * (*this_out_grad)); + atomicAdd(this_data_grad + 2, (partial_y1 * (1.0 - float(ph) / pooled_height) + partial_y2 * (1.0 - float(ph + 1) / pooled_height)) + * (*this_out_grad)); + atomicAdd(this_data_grad + 3, (partial_x2 * float(pw + 1) / pooled_width + partial_x1 * float(pw) / pooled_width) + * (*this_out_grad)); + atomicAdd(this_data_grad + 4, (partial_y2 * float(ph + 1) / pooled_height + partial_y1 * float(ph) / pooled_height) + * (*this_out_grad)); + } +} + +} /* !anonymous namespace */ + +#ifdef __cplusplus +extern "C" { +#endif + +void PrRoIPoolingForwardGpu( + cudaStream_t stream, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_OUT top_data, + const int channels_, const int height_, const int width_, + const int pooled_height_, const int pooled_width_, + const float spatial_scale_, + const int top_count) { + + PrRoIPoolingForward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>( + top_count, bottom_data, bottom_rois, top_data, + channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_); + + CUDA_POST_KERNEL_CHECK; +} + +void PrRoIPoolingBackwardGpu( + cudaStream_t stream, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_IN top_data, + F_DEVPTR_IN top_diff, + F_DEVPTR_OUT bottom_diff, + const 
int channels_, const int height_, const int width_, + const int pooled_height_, const int pooled_width_, + const float spatial_scale_, + const int top_count, const int bottom_count) { + + cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream); + PrRoIPoolingBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>( + top_count, bottom_rois, top_diff, bottom_diff, + channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_); + CUDA_POST_KERNEL_CHECK; +} + +void PrRoIPoolingCoorBackwardGpu( + cudaStream_t stream, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_IN top_data, + F_DEVPTR_IN top_diff, + F_DEVPTR_OUT bottom_diff, + const int channels_, const int height_, const int width_, + const int pooled_height_, const int pooled_width_, + const float spatial_scale_, + const int top_count, const int bottom_count) { + + cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream); + PrRoIPoolingCoorBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>( + top_count, bottom_data, bottom_rois, top_data, top_diff, bottom_diff, + channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_); + CUDA_POST_KERNEL_CHECK; +} + +} /* !extern "C" */ + diff --git a/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cuh b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..95ad56797ca6299ededf63718d742343f8dab8e7 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cuh @@ -0,0 +1,59 @@ +/* + * File : prroi_pooling_gpu_impl.cuh + * Author : Tete Xiao, Jiayuan Mao + * Email : jasonhsiao97@gmail.com + * + * Distributed under terms of the MIT license. + * Copyright (c) 2017 Megvii Technology Limited. + */ + +#ifndef PRROI_POOLING_GPU_IMPL_CUH +#define PRROI_POOLING_GPU_IMPL_CUH + +#ifdef __cplusplus +extern "C" { +#endif + +#define F_DEVPTR_IN const float * +#define F_DEVPTR_OUT float * + +void PrRoIPoolingForwardGpu( + cudaStream_t stream, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_OUT top_data, + const int channels_, const int height_, const int width_, + const int pooled_height_, const int pooled_width_, + const float spatial_scale_, + const int top_count); + +void PrRoIPoolingBackwardGpu( + cudaStream_t stream, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_IN top_data, + F_DEVPTR_IN top_diff, + F_DEVPTR_OUT bottom_diff, + const int channels_, const int height_, const int width_, + const int pooled_height_, const int pooled_width_, + const float spatial_scale_, + const int top_count, const int bottom_count); + +void PrRoIPoolingCoorBackwardGpu( + cudaStream_t stream, + F_DEVPTR_IN bottom_data, + F_DEVPTR_IN bottom_rois, + F_DEVPTR_IN top_data, + F_DEVPTR_IN top_diff, + F_DEVPTR_OUT bottom_diff, + const int channels_, const int height_, const int width_, + const int pooled_height_, const int pooled_width_, + const float spatial_scale_, + const int top_count, const int bottom_count); + +#ifdef __cplusplus +} /* !extern "C" */ +#endif + +#endif /* !PRROI_POOLING_GPU_IMPL_CUH */ + diff --git a/netdissect/upsegmodel/prroi_pool/test_prroi_pooling2d.py b/netdissect/upsegmodel/prroi_pool/test_prroi_pooling2d.py new file mode 100644 index 0000000000000000000000000000000000000000..a29d92c80538f5550808dc51f92dcaf65cbd9fb0 --- /dev/null +++ b/netdissect/upsegmodel/prroi_pool/test_prroi_pooling2d.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_prroi_pooling2d.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 18/02/2018 +# +# This file is part of 
Jacinle. + +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from jactorch.utils.unittest import TorchTestCase + +from prroi_pool import PrRoIPool2D + + +class TestPrRoIPool2D(TorchTestCase): + def test_forward(self): + pool = PrRoIPool2D(7, 7, spatial_scale=0.5) + features = torch.rand((4, 16, 24, 32)).cuda() + rois = torch.tensor([ + [0, 0, 0, 14, 14], + [1, 14, 14, 28, 28], + ]).float().cuda() + + out = pool(features, rois) + out_gold = F.avg_pool2d(features, kernel_size=2, stride=1) + + self.assertTensorClose(out, torch.stack(( + out_gold[0, :, :7, :7], + out_gold[1, :, 7:14, 7:14], + ), dim=0)) + + def test_backward_shapeonly(self): + pool = PrRoIPool2D(2, 2, spatial_scale=0.5) + + features = torch.rand((4, 2, 24, 32)).cuda() + rois = torch.tensor([ + [0, 0, 0, 4, 4], + [1, 14, 14, 18, 18], + ]).float().cuda() + features.requires_grad = rois.requires_grad = True + out = pool(features, rois) + + loss = out.sum() + loss.backward() + + self.assertTupleEqual(features.size(), features.grad.size()) + self.assertTupleEqual(rois.size(), rois.grad.size()) + + +if __name__ == '__main__': + unittest.main() diff --git a/netdissect/upsegmodel/resnet.py b/netdissect/upsegmodel/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5fdf82fafa3058c5f00074d55fbb1e584d5865 --- /dev/null +++ b/netdissect/upsegmodel/resnet.py @@ -0,0 +1,235 @@ +import os +import sys +import torch +import torch.nn as nn +import math +try: + from lib.nn import SynchronizedBatchNorm2d +except ImportError: + from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + + +__all__ = ['ResNet', 'resnet50', 'resnet101'] 
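+# Illustrative usage sketch for the constructors defined below; the 224x224
+# input size is an assumption that matches the stride-32 stem and the fixed
+# 7x7 average pool before the classifier:
+#
+#     model = resnet50(pretrained=True)            # weights cached by load_url() under ./pretrained
+#     logits = model(torch.randn(2, 3, 224, 224))  # -> (2, num_classes)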
+ + +model_urls = { + 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', + 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = SynchronizedBatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = SynchronizedBatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = SynchronizedBatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = SynchronizedBatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, SynchronizedBatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + SynchronizedBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + +''' +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet34'])) + return model +''' + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet101']), strict=False) + return model + +# def resnet152(pretrained=False, **kwargs): +# """Constructs a ResNet-152 model. 
+# +# Args: +# pretrained (bool): If True, returns a model pre-trained on Places +# """ +# model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnet152'])) +# return model + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) diff --git a/netdissect/upsegmodel/resnext.py b/netdissect/upsegmodel/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..4c618c9da5be17feb975833532e19474fca82dba --- /dev/null +++ b/netdissect/upsegmodel/resnext.py @@ -0,0 +1,183 @@ +import os +import sys +import torch +import torch.nn as nn +import math +try: + from lib.nn import SynchronizedBatchNorm2d +except ImportError: + from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + + +__all__ = ['ResNeXt', 'resnext101'] # support resnext 101 + + +model_urls = { + #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', + 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class GroupBottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): + super(GroupBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = SynchronizedBatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = SynchronizedBatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) + self.bn3 = SynchronizedBatchNorm2d(planes * 2) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNeXt(nn.Module): + + def __init__(self, block, layers, groups=32, num_classes=1000): + self.inplanes = 128 + super(ResNeXt, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = SynchronizedBatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = SynchronizedBatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = SynchronizedBatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) + self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) + self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) + self.layer4 = self._make_layer(block, 1024, layers[3], 
stride=2, groups=groups) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(1024 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, SynchronizedBatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, groups=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + SynchronizedBatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, groups, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +''' +def resnext50(pretrained=False, **kwargs): + """Constructs a ResNeXt-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext50']), strict=False) + return model +''' + + +def resnext101(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext101']), strict=False) + return model + + +# def resnext152(pretrained=False, **kwargs): +# """Constructs a ResNeXt-152 model. +# +# Args: +# pretrained (bool): If True, returns a model pre-trained on Places +# """ +# model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnext152'])) +# return model + + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) diff --git a/netdissect/workerpool.py b/netdissect/workerpool.py new file mode 100644 index 0000000000000000000000000000000000000000..fe79124ddc86d0e7251d9e1a5d1012e7165249e3 --- /dev/null +++ b/netdissect/workerpool.py @@ -0,0 +1,158 @@ +''' +WorkerPool and WorkerBase for handling the common problems in managing +a multiprocess pool of workers that aren't done by multiprocessing.Pool, +including setup with per-process state, debugging by putting the worker +on the main thread, and correct handling of unexpected errors and ctrl-C. + +To use it, +1. Put the per-process setup and the per-task work in the + setup() and work() methods of your own WorkerBase subclass. +2. 
To prepare the process pool, instantiate a WorkerPool, passing your + subclass type as the first (worker) argument, as well as any setup keyword + arguments. The WorkerPool will instantiate one of your workers in each + worker process (passing in the setup arguments in those processes). + If debugging, the pool can have process_count=0 to force all the work + to be done immediately on the main thread; otherwise all the work + will be passed to other processes. +3. Whenever there is a new piece of work to distribute, call pool.add(*args). + The arguments will be queued and passed as worker.work(*args) to the + next available worker. +4. When all the work has been distributed, call pool.join() to wait for all + the work to complete and to terminate all the worker processes. + When pool.join() returns, all the work will have been done. + +No arrangement is made to collect the results of the work: for example, +the return value of work() is ignored. If you need to collect the +results, use your own mechanism (filesystem, shared memory object, queue) +which can be distributed using setup arguments. +''' + +from multiprocessing import Process, Queue, cpu_count +import signal +import atexit +import sys + +class WorkerBase(Process): + ''' + Subclass this class and override its work() method (and optionally, + setup() as well) to define the units of work to be done in a process + worker in a worker pool. + ''' + def __init__(self, i, process_count, queue, initargs): + if process_count > 0: + # Make sure we ignore ctrl-C if we are not on main process. + signal.signal(signal.SIGINT, signal.SIG_IGN) + self.process_id = i + self.process_count = process_count + self.queue = queue + super(WorkerBase, self).__init__() + self.setup(**initargs) + def run(self): + # Do the work until None is dequeued + while True: + try: + work_batch = self.queue.get() + except (KeyboardInterrupt, SystemExit): + print('Exiting...') + break + if work_batch is None: + self.queue.put(None) # for another worker + return + self.work(*work_batch) + def setup(self, **initargs): + ''' + Override this method for any per-process initialization. + Keyword args are passed from the WorkerPool constructor. + ''' + pass + def work(self, *args): + ''' + Override this method to do a single unit of work. + Args are passed from WorkerPool.add() arguments. + ''' + raise NotImplementedError('worker subclass needed') + +class WorkerPool(object): + ''' + Instantiate this object (passing a WorkerBase subclass type + as its first argument) to create a worker pool. Then call + pool.add(*args) to queue args to distribute to worker.work(*args), + and call pool.join() to wait for all the workers to complete. + ''' + def __init__(self, worker=WorkerBase, process_count=None, **initargs): + global active_pools + if process_count is None: + process_count = cpu_count() + if process_count == 0: + # zero process_count uses only main process, for debugging. + self.queue = None + self.processes = None + self.worker = worker(None, 0, None, initargs) + return + # Ctrl-C strategy: worker processes should ignore ctrl-C. Set + # this up to be inherited by child processes before forking.
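+ # (The SIG_IGN disposition installed here is inherited by the worker
+ # processes created by start() below, while the saved handler is restored
+ # in the main process right after the workers start, so only the parent
+ # reacts to ctrl-C.)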
+ original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) + active_pools[id(self)] = self + self.queue = Queue(maxsize=(process_count * 3)) + self.processes = None # Initialize before trying to construct workers + self.processes = [worker(i, process_count, self.queue, initargs) + for i in range(process_count)] + for p in self.processes: + p.start() + # The main process should handle ctrl-C. Restore this now. + signal.signal(signal.SIGINT, original_sigint_handler) + def add(self, *work_batch): + if self.queue is None: + if hasattr(self, 'worker'): + self.worker.work(*work_batch) + else: + print('WorkerPool shutting down.', file=sys.stderr) + else: + try: + # The queue can block if the work is so slow it gets full. + self.queue.put(work_batch) + except (KeyboardInterrupt, SystemExit): + # Handle ctrl-C if done while waiting for the queue. + self.early_terminate() + def join(self): + # End the queue, and wait for all worker processes to complete nicely. + if self.queue is not None: + self.queue.put(None) + for p in self.processes: + p.join() + self.queue = None + # Remove myself from the set of pools that need cleanup on shutdown. + try: + del active_pools[id(self)] + except: + pass + def early_terminate(self): + # When shutting down unexpectedly, first end the queue. + if self.queue is not None: + try: + self.queue.put_nowait(None) # Nonblocking put throws if full. + self.queue = None + except: + pass + # But then don't wait: just forcibly terminate workers. + if self.processes is not None: + for p in self.processes: + p.terminate() + self.processes = None + try: + del active_pools[id(self)] + except: + pass + def __del__(self): + if self.queue is not None: + print('ERROR: workerpool.join() not called!', file=sys.stderr) + self.join() + +# Error and ctrl-C handling: kill worker processes if the main process ends. +active_pools = {} +def early_terminate_pools(): + for _, pool in list(active_pools.items()): + pool.early_terminate() + +atexit.register(early_terminate_pools) + diff --git a/netdissect/zdataset.py b/netdissect/zdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eb085d83d676fa1e4b1f1b053dc6f1ba2ff35381 --- /dev/null +++ b/netdissect/zdataset.py @@ -0,0 +1,41 @@ +import os, torch, numpy +from torch.utils.data import TensorDataset + +def z_dataset_for_model(model, size=100, seed=1): + return TensorDataset(z_sample_for_model(model, size, seed)) + +def z_sample_for_model(model, size=100, seed=1): + # If the model is marked with an input shape, use it. + if hasattr(model, 'input_shape'): + sample = standard_z_sample(size, model.input_shape[1], seed=seed).view( + (size,) + model.input_shape[1:]) + return sample + # Examine first conv in model to determine input feature size. + first_layer = [c for c in model.modules() + if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d, + torch.nn.Linear))][0] + # 4d input if convolutional, 2d input if first layer is linear. + if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): + sample = standard_z_sample( + size, first_layer.in_channels, seed=seed)[:,:,None,None] + else: + sample = standard_z_sample( + size, first_layer.in_features, seed=seed) + return sample + +def standard_z_sample(size, depth, seed=1, device=None): + ''' + Generate a standard set of random Z as a (size, z_dimension) tensor. + With the same random seed, it always returns the same z (e.g., + the first one is always the same regardless of the size.) 
+ ''' + # Use numpy RandomState since it can be done deterministically + # without affecting global state + rng = numpy.random.RandomState(seed) + result = torch.from_numpy( + rng.standard_normal(size * depth) + .reshape(size, depth)).float() + if device is not None: + result = result.to(device) + return result + diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9f1eea092d5e971b5475b82ee835cec7f196bad --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa7463e44ac519f7e34bcaa76b69c034f7d4a13d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +boto3 +nltk +pillow==6.2 +fbpca +ninja +torch==1.9.0 +torchvision==0.10.0 +torchtext==0.10.0 +scikit_learn==1.0 +scikit-image==0.18.3 +tqdm +gdown \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5498289425bb70e959c0194eb7c6fab63e0c045a --- /dev/null +++ b/utils.py @@ -0,0 +1,92 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +import string +import numpy as np +from pathlib import Path +import requests +import pickle +import sys +import re +import gdown + +def prettify_name(name): + valid = "-_%s%s" % (string.ascii_letters, string.digits) + return ''.join(map(lambda c : c if c in valid else '_', name)) + +# Add padding to sequence of images +# Used in conjunction with np.hstack/np.vstack +# By default: adds one 64th of the width of horizontal padding +def pad_frames(strip, pad_fract_horiz=64, pad_fract_vert=0, pad_value=None): + dtype = strip[0].dtype + if pad_value is None: + if dtype in [np.float32, np.float64]: + pad_value = 1.0 + else: + pad_value = np.iinfo(dtype).max + + frames = [strip[0]] + for frame in strip[1:]: + if pad_fract_horiz > 0: + frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, 3), dtype=dtype)) + elif pad_fract_vert > 0: + frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], 3), dtype=dtype)) + frames.append(frame) + return frames + + +def download_google_drive(url, output_name): + print('Downloading', url) + gdown.download(url, str(output_name)) + # session = requests.Session() + # r = session.get(url, allow_redirects=True) + # r.raise_for_status() + + # # Google Drive virus check message + # if r.encoding is not None: + # tokens = re.search('(confirm=.+)&id', str(r.content)) + # assert tokens is not None, 'Could not extract token from response' + + # url = url.replace('id=', f'{tokens[1]}&id=') + # r = session.get(url, allow_redirects=True) + # r.raise_for_status() + + # assert r.encoding is None, f'Failed to download weight file from {url}' + + # with open(output_name, 'wb') as f: + # f.write(r.content) + +def download_generic(url, output_name): + print('Downloading', url) + session = requests.Session() + r = 
session.get(url, allow_redirects=True) + r.raise_for_status() + + # No encoding means raw data + if r.encoding is None: + with open(output_name, 'wb') as f: + f.write(r.content) + else: + download_manual(url, output_name) + +def download_manual(url, output_name): + outpath = Path(output_name).resolve() + while not outpath.is_file(): + print('Could not find checkpoint') + print(f'Please download the checkpoint from\n{url}\nand save it as\n{outpath}') + input('Press Enter to continue...') + +def download_ckpt(url, output_name): + if 'drive.google' in url: + download_google_drive(url, output_name) + elif 'mega.nz' in url: + download_manual(url, output_name) + else: + download_generic(url, output_name) \ No newline at end of file diff --git a/visualize.py b/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..433ae2ea8963c56a37e5e91932ad6d359495ed47 --- /dev/null +++ b/visualize.py @@ -0,0 +1,314 @@ +# Copyright 2020 Erik Härkönen. All rights reserved. +# This file is licensed to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy +# of the License at http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS +# OF ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Patch for broken CTRL+C handler +# https://github.com/ContinuumIO/anaconda-issues/issues/905 +import os +os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1' + +import torch, json, numpy as np +from types import SimpleNamespace +import matplotlib.pyplot as plt +from pathlib import Path +from os import makedirs +from PIL import Image +from netdissect import proggan, nethook, easydict, zdataset +from netdissect.modelconfig import create_instrumented_model +from estimators import get_estimator +from models import get_instrumented_model +from scipy.cluster.vq import kmeans +import re +import sys +import datetime +import argparse +from tqdm import trange +from config import Config +from decomposition import get_random_dirs, get_or_compute, get_max_batch_size, SEED_VISUALIZATION +from utils import pad_frames + +def x_closest(p): + distances = np.sqrt(np.sum((X - p)**2, axis=-1)) + idx = np.argmin(distances) + return distances[idx], X[idx] + +def make_gif(imgs, duration_secs, outname): + head, *tail = [Image.fromarray((x * 255).astype(np.uint8)) for x in imgs] + ms_per_frame = 1000 * duration_secs / len(imgs) + head.save(outname, format='GIF', append_images=tail, save_all=True, duration=ms_per_frame, loop=0) + +def make_mp4(imgs, duration_secs, outname): + import shutil + import subprocess as sp + + FFMPEG_BIN = shutil.which("ffmpeg") + assert FFMPEG_BIN is not None, 'ffmpeg not found, install with "conda install -c conda-forge ffmpeg"' + assert len(imgs[0].shape) == 3, 'Invalid shape of frame data' + + resolution = imgs[0].shape[0:2] + fps = int(len(imgs) / duration_secs) + + command = [ FFMPEG_BIN, + '-y', # overwrite output file + '-f', 'rawvideo', + '-vcodec','rawvideo', + '-s', f'{resolution[1]}x{resolution[0]}', # size of one frame (width x height) + '-pix_fmt', 'rgb24', + '-r', f'{fps}', + '-i', '-', # input from pipe + '-an', # no audio + '-c:v', 'libx264', + '-preset', 'slow', + '-crf', '17', + str(Path(outname).with_suffix('.mp4')) ] + + frame_data = np.concatenate([(x 
* 255).astype(np.uint8).reshape(-1) for x in imgs]) + with sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE) as p: + ret = p.communicate(frame_data.tobytes()) + if p.returncode != 0: + print(ret[1].decode("utf-8")) + raise sp.CalledProcessError(p.returncode, command) + + +def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_stdev, scale=1, n_rows=10, n_cols=5, make_plots=True, edit_type='latent'): + from notebooks.notebook_utils import create_strip_centered + + inst.remove_edits() + x_range = np.linspace(-scale, scale, n_cols, dtype=np.float32) # scale in sigmas + + rows = [] + for r in range(n_rows): + curr_row = [] + out_batch = create_strip_centered(inst, edit_type, layer_key, [latent], + act_comp[r], lat_comp[r], act_stdev[r], lat_stdev[r], act_mean, lat_mean, scale, 0, -1, n_cols)[0] + for i, img in enumerate(out_batch): + curr_row.append(('c{}_{:.2f}'.format(r, x_range[i]), img)) + + rows.append(curr_row[:n_cols]) + + inst.remove_edits() + + if make_plots: + # If more rows than columns, make several blocks side by side + n_blocks = 2 if n_rows > n_cols else 1 + + for r, data in enumerate(rows): + # Add white borders + imgs = pad_frames([img for _, img in data]) + + coord = ((r * n_blocks) % n_rows) + ((r * n_blocks) // n_rows) + plt.subplot(n_rows//n_blocks, n_blocks, 1 + coord) + plt.imshow(np.hstack(imgs)) + + # Custom x-axis labels + W = imgs[0].shape[1] # image width + P = imgs[1].shape[1] # padding width + locs = [(0.5*W + i*(W+P)) for i in range(n_cols)] + plt.xticks(locs, ["{:.2f}".format(v) for v in x_range]) + plt.yticks([]) + plt.ylabel(f'C{r}') + + plt.tight_layout() + plt.subplots_adjust(top=0.96) # make room for suptitle + + return [img for row in rows for img in row] + + +###################### +### Visualize results +###################### + +if __name__ == '__main__': + global max_batch, sample_shape, feature_shape, inst, args, layer_key, model + + args = Config().from_args() + t_start = datetime.datetime.now() + timestamp = lambda : datetime.datetime.now().strftime("%d.%m %H:%M") + print(f'[{timestamp()}] {args.model}, {args.layer}, {args.estimator}') + + # Ensure reproducibility + torch.manual_seed(0) # also sets cuda seeds + np.random.seed(0) + + # Speed up backend + torch.backends.cudnn.benchmark = True + torch.autograd.set_grad_enabled(False) + + has_gpu = torch.cuda.is_available() + device = torch.device('cuda' if has_gpu else 'cpu') + layer_key = args.layer + layer_name = layer_key #layer_key.lower().split('.')[-1] + + basedir = Path(__file__).parent.resolve() + outdir = basedir / 'out' + + # Load model + inst = get_instrumented_model(args.model, args.output_class, layer_key, device, use_w=args.use_w) + model = inst.model + feature_shape = inst.feature_shape[layer_key] + latent_shape = model.get_latent_shape() + print('Feature shape:', feature_shape) + + # Layout of activations + if len(feature_shape) != 4: # non-spatial + axis_mask = np.ones(len(feature_shape), dtype=np.int32) + else: + axis_mask = np.array([0, 1, 1, 1]) # only batch fixed => whole activation volume used + + # Shape of sample passed to PCA + sample_shape = feature_shape*axis_mask + sample_shape[sample_shape == 0] = 1 + + # Load or compute components + dump_name = get_or_compute(args, inst) + data = np.load(dump_name, allow_pickle=False) # does not contain object arrays + X_comp = data['act_comp'] + X_global_mean = data['act_mean'] + X_stdev = data['act_stdev'] + X_var_ratio = data['var_ratio'] + X_stdev_random = data['random_stdevs'] + Z_global_mean = 
data['lat_mean'] + Z_comp = data['lat_comp'] + Z_stdev = data['lat_stdev'] + n_comp = X_comp.shape[0] + data.close() + + # Transfer components to device + tensors = SimpleNamespace( + X_comp = torch.from_numpy(X_comp).to(device).float(), #-1, 1, C, H, W + X_global_mean = torch.from_numpy(X_global_mean).to(device).float(), # 1, C, H, W + X_stdev = torch.from_numpy(X_stdev).to(device).float(), + Z_comp = torch.from_numpy(Z_comp).to(device).float(), + Z_stdev = torch.from_numpy(Z_stdev).to(device).float(), + Z_global_mean = torch.from_numpy(Z_global_mean).to(device).float(), + ) + + transformer = get_estimator(args.estimator, n_comp, args.sparsity) + tr_param_str = transformer.get_param_str() + + # Compute max batch size given VRAM usage + max_batch = args.batch_size or (get_max_batch_size(inst, device) if has_gpu else 1) + print('Batch size:', max_batch) + + def show(): + if args.batch_mode: + plt.close('all') + else: + plt.show() + + print(f'[{timestamp()}] Creating visualizations') + + # Ensure visualization gets new samples + torch.manual_seed(SEED_VISUALIZATION) + np.random.seed(SEED_VISUALIZATION) + + # Make output directories + est_id = f'spca_{args.sparsity}' if args.estimator == 'spca' else args.estimator + outdir_comp = outdir/model.name/layer_key.lower()/est_id/'comp' + outdir_inst = outdir/model.name/layer_key.lower()/est_id/'inst' + outdir_summ = outdir/model.name/layer_key.lower()/est_id/'summ' + makedirs(outdir_comp, exist_ok=True) + makedirs(outdir_inst, exist_ok=True) + makedirs(outdir_summ, exist_ok=True) + + # Measure component sparsity (!= activation sparsity) + sparsity = np.mean(X_comp == 0) # percentage of zero values in components + print(f'Sparsity: {sparsity:.2f}') + + def get_edit_name(mode): + if mode == 'activation': + is_stylegan = 'StyleGAN' in args.model + is_w = layer_key in ['style', 'g_mapping'] + return 'W' if (is_stylegan and is_w) else 'ACT' + elif mode == 'latent': + return model.latent_space_name() + elif mode == 'both': + return 'BOTH' + else: + raise RuntimeError(f'Unknown edit mode {mode}') + + # Only visualize applicable edit modes + if args.use_w and layer_key in ['style', 'g_mapping']: + edit_modes = ['latent'] # activation edit is the same + else: + edit_modes = ['activation', 'latent'] + + # Summary grid, real components + for edit_mode in edit_modes: + plt.figure(figsize = (14,12)) + plt.suptitle(f"{args.estimator.upper()}: {model.name} - {layer_name}, {get_edit_name(edit_mode)} edit", size=16) + make_grid(tensors.Z_global_mean, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev, tensors.X_global_mean, + tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14) + plt.savefig(outdir_summ / f'components_{get_edit_name(edit_mode)}.jpg', dpi=300) + show() + + if args.make_video: + components = 15 + instances = 150 + + # One reasonable, one over the top + for sigma in [args.sigma, 3*args.sigma]: + for c in range(components): + for edit_mode in edit_modes: + frames = make_grid(tensors.Z_global_mean, tensors.Z_global_mean, tensors.Z_comp[c:c+1, :, :], tensors.Z_stdev[c:c+1], tensors.X_global_mean, + tensors.X_comp[c:c+1, :, :], tensors.X_stdev[c:c+1], n_rows=1, n_cols=instances, scale=sigma, make_plots=False, edit_type=edit_mode) + plt.close('all') + + frames = [x for _, x in frames] + frames = frames + frames[::-1] + make_mp4(frames, 5, outdir_comp / f'{get_edit_name(edit_mode)}_sigma{sigma}_comp{c}.mp4') + + + # Summary grid, random directions + # Using the stdevs of the principal components for same norm + 
random_dirs_act = torch.from_numpy(get_random_dirs(n_comp, np.prod(sample_shape)).reshape(-1, *sample_shape)).to(device) + random_dirs_z = torch.from_numpy(get_random_dirs(n_comp, np.prod(inst.input_shape)).reshape(-1, *latent_shape)).to(device) + + for edit_mode in edit_modes: + plt.figure(figsize = (14,12)) + plt.suptitle(f"{model.name} - {layer_name}, random directions w/ PC stdevs, {get_edit_name(edit_mode)} edit", size=16) + make_grid(tensors.Z_global_mean, tensors.Z_global_mean, random_dirs_z, tensors.Z_stdev, + tensors.X_global_mean, random_dirs_act, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14) + plt.savefig(outdir_summ / f'random_dirs_{get_edit_name(edit_mode)}.jpg', dpi=300) + show() + + # Random instances w/ components added + n_random_imgs = 10 + latents = model.sample_latent(n_samples=n_random_imgs) + + for img_idx in trange(n_random_imgs, desc='Random images', ascii=True): + #print(f'Creating visualizations for random image {img_idx+1}/{n_random_imgs}') + z = latents[img_idx][None, ...] + + # Summary grid, real components + for edit_mode in edit_modes: + plt.figure(figsize = (14,12)) + plt.suptitle(f"{args.estimator.upper()}: {model.name} - {layer_name}, {get_edit_name(edit_mode)} edit", size=16) + make_grid(z, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev, + tensors.X_global_mean, tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14) + plt.savefig(outdir_summ / f'samp{img_idx}_real_{get_edit_name(edit_mode)}.jpg', dpi=300) + show() + + if args.make_video: + components = 5 + instances = 150 + + # One reasonable, one over the top + for sigma in [args.sigma, 3*args.sigma]: #[2, 5]: + for edit_mode in edit_modes: + imgs = make_grid(z, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev, tensors.X_global_mean, tensors.X_comp, tensors.X_stdev, + n_rows=components, n_cols=instances, scale=sigma, make_plots=False, edit_type=edit_mode) + plt.close('all') + + for c in range(components): + frames = [x for _, x in imgs[c*instances:(c+1)*instances]] + frames = frames + frames[::-1] + make_mp4(frames, 5, outdir_inst / f'{get_edit_name(edit_mode)}_sigma{sigma}_img{img_idx}_comp{c}.mp4') + + print('Done in', datetime.datetime.now() - t_start) \ No newline at end of file
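A minimal usage sketch for the WorkerPool/WorkerBase API in netdissect/workerpool.py, following its module docstring: subclass WorkerBase, put per-process state in setup() and per-task work in work(), then queue work with pool.add() and finish with pool.join(). The SquareWorker class and the results Queue below are illustrative assumptions, not part of the files above; results are collected through a Queue passed as a setup argument, as the docstring suggests.

    from multiprocessing import Queue
    from netdissect.workerpool import WorkerBase, WorkerPool

    class SquareWorker(WorkerBase):
        def setup(self, results=None):
            # Per-process state, passed through WorkerPool(...) keyword arguments.
            self.results = results

        def work(self, value):
            # One unit of work per pool.add(...) call.
            self.results.put(value * value)

    if __name__ == '__main__':
        results = Queue()
        pool = WorkerPool(worker=SquareWorker, process_count=4, results=results)
        for v in range(100):
            pool.add(v)   # queued, then dispatched as worker.work(v)
        pool.join()       # waits for all work to finish and shuts the workers down
        print(sorted(results.get() for _ in range(100))[:5])  # -> [0, 1, 4, 9, 16]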