diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..eea938df748bd62dde8bdf5b9384d3dcd8675dad
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 Erik Härkönen
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index d74693852da735718f39d1aad452481ef6936f01..48d92436c16afb605cad08292b2ebe1528434e9c 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
---
-title: Guccio AI Designer
-emoji: 📉
-colorFrom: green
-colorTo: purple
+title: ClothingGAN
+emoji: 👘
+colorFrom: indigo
+colorTo: gray
sdk: gradio
-sdk_version: 3.1.4
+sdk_version: 2.9.4
app_file: app.py
pinned: false
+license: cc-by-nc-3.0
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..cefb7d0a530ae0474a0725aed487e07fb4d020e4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,288 @@
+import nltk; nltk.download('wordnet')
+
+#@title Load Model
+selected_model = 'lookbook'
+
+# Load model
+from IPython.utils import io
+import torch
+import PIL
+import numpy as np
+import ipywidgets as widgets
+from PIL import Image
+import imageio
+from models import get_instrumented_model
+from decomposition import get_or_compute
+from config import Config
+from skimage import img_as_ubyte
+import gradio as gr
+from ipywidgets import fixed
+
+# Speed up computation
+torch.autograd.set_grad_enabled(False)
+torch.backends.cudnn.benchmark = True
+
+# Specify model to use
+config = Config(
+ model='StyleGAN2',
+ layer='style',
+ output_class=selected_model,
+ components=80,
+ use_w=True,
+ batch_size=5_000, # style layer quite small
+)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+inst = get_instrumented_model(config.model, config.output_class,
+ config.layer, torch.device(device), use_w=config.use_w)
+
+path_to_components = get_or_compute(config, inst)
+
+model = inst.model
+
+comps = np.load(path_to_components)
+lst = comps.files
+latent_dirs = []
+latent_stdevs = []
+
+load_activations = False
+
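+# The components .npz stores the principal directions and their standard
+# deviations both in activation space (act_comp / act_stdev) and in latent
+# space (lat_comp / lat_stdev); load_activations selects which set is used.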
+for item in lst:
+ if load_activations:
+ if item == 'act_comp':
+ for i in range(comps[item].shape[0]):
+ latent_dirs.append(comps[item][i])
+ if item == 'act_stdev':
+ for i in range(comps[item].shape[0]):
+ latent_stdevs.append(comps[item][i])
+ else:
+ if item == 'lat_comp':
+ for i in range(comps[item].shape[0]):
+ latent_dirs.append(comps[item][i])
+ if item == 'lat_stdev':
+ for i in range(comps[item].shape[0]):
+ latent_stdevs.append(comps[item][i])
+
+
+#@title Define functions
+
+
+# Taken from https://github.com/alexanderkuk/log-progress
+def log_progress(sequence, every=1, size=None, name='Items'):
+ from ipywidgets import IntProgress, HTML, VBox
+ from IPython.display import display
+
+ is_iterator = False
+ if size is None:
+ try:
+ size = len(sequence)
+ except TypeError:
+ is_iterator = True
+ if size is not None:
+ if every is None:
+ if size <= 200:
+ every = 1
+ else:
+ every = int(size / 200) # every 0.5%
+ else:
+ assert every is not None, 'sequence is iterator, set every'
+
+ if is_iterator:
+ progress = IntProgress(min=0, max=1, value=1)
+ progress.bar_style = 'info'
+ else:
+ progress = IntProgress(min=0, max=size, value=0)
+ label = HTML()
+ box = VBox(children=[label, progress])
+ display(box)
+
+ index = 0
+ try:
+ for index, record in enumerate(sequence, 1):
+ if index == 1 or index % every == 0:
+ if is_iterator:
+ label.value = '{name}: {index} / ?'.format(
+ name=name,
+ index=index
+ )
+ else:
+ progress.value = index
+ label.value = u'{name}: {index} / {size}'.format(
+ name=name,
+ index=index,
+ size=size
+ )
+ yield record
+ except:
+ progress.bar_style = 'danger'
+ raise
+ else:
+ progress.bar_style = 'success'
+ progress.value = index
+ label.value = "{name}: {index}".format(
+ name=name,
+ index=str(index or '?')
+ )
+
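+# NOTE: name_direction/save_direction are carried over from the original
+# notebook UI; they reference widget state (text, num, named_directions,
+# random_dir) that is not defined in this script.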
+def name_direction(sender):
+ if not text.value:
+ print('Please name the direction before saving')
+ return
+
+ if num in named_directions.values():
+ target_key = list(named_directions.keys())[list(named_directions.values()).index(num)]
+ print(f'Direction already named: {target_key}')
+ print(f'Overwriting... ')
+ del(named_directions[target_key])
+ named_directions[text.value] = [num, start_layer.value, end_layer.value]
+ save_direction(random_dir, text.value)
+ for item in named_directions:
+ print(item, named_directions[item])
+
+def save_direction(direction, filename):
+ filename += ".npy"
+ np.save(filename, direction, allow_pickle=True, fix_imports=True)
+ print(f'Latent direction saved as {filename}')
+
+def mix_w(w1, w2, content, style):
+ for i in range(0,5):
+ w2[i] = w1[i] * (1 - content) + w2[i] * content
+
+ for i in range(5, 16):
+ w2[i] = w1[i] * (1 - style) + w2[i] * style
+
+ return w2
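+
+# Example use of mix_w (illustrative): mix_w(w1, w2, content=0.0, style=1.0)
+# keeps the first 5 latent layers entirely from w1 (seed 1 drives the coarse
+# structure) and layers 5-15 from w2 (seed 2 drives the finer style); note
+# that the w2 list is modified in place and returned.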
+
+def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):
+ # blockPrint()
+ model.truncation = truncation
+ if w is None:
+ w = model.sample_latent(1, seed=seed).detach().cpu().numpy()
+ w = [w]*model.get_max_latents() # one per layer
+ else:
+ w = [np.expand_dims(x, 0) for x in w]
+
+ for l in range(start, end):
+ for i in range(len(directions)):
+ w[l] = w[l] + directions[i] * distances[i] * scale
+
+ torch.cuda.empty_cache()
+ #save image and display
+ out = model.sample_np(w)
+ final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS)
+
+
+    if save is not None:
+        if not disp:
+            print(save)
+        final_im.save(f'out/{seed}_{save:05}.png')
+    if disp:
+        # `display` is only available in IPython/notebook environments
+        from IPython.display import display
+        display(final_im)
+
+ return final_im
+
+def generate_mov(seed, truncation, direction_vec, scale, layers, n_frames, out_name = 'out', noise_spec = None, loop=True):
+ """Generates a mov moving back and forth along the chosen direction vector"""
+ # Example of reading a generated set of images, and storing as MP4.
+ movieName = f'{out_name}.mp4'
+ offset = -10
+ step = 20 / n_frames
+ imgs = []
+ for i in log_progress(range(n_frames), name = "Generating frames"):
+ print(f'\r{i} / {n_frames}', end='')
+ w = model.sample_latent(1, seed=seed).cpu().numpy()
+
+ model.truncation = truncation
+ w = [w]*model.get_max_latents() # one per layer
+ for l in layers:
+ if l <= model.get_max_latents():
+ w[l] = w[l] + direction_vec * offset * scale
+
+ #save image and display
+ out = model.sample_np(w)
+ final_im = Image.fromarray((out * 255).astype(np.uint8))
+ imgs.append(out)
+ #increase offset
+ offset += step
+ if loop:
+ imgs += imgs[::-1]
+ with imageio.get_writer(movieName, mode='I') as writer:
+ for image in log_progress(list(imgs), name = "Creating animation"):
+ writer.append_data(img_as_ubyte(image))
+
+
+#@title Demo UI
+
+
+def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer):
+ seed1 = int(seed1)
+ seed2 = int(seed2)
+
+ scale = 1
+ params = {'c0': c0,
+ 'c1': c1,
+ 'c2': c2,
+ 'c3': c3,
+ 'c4': c4,
+ 'c5': c5,
+ 'c6': c6}
+
+ param_indexes = {'c0': 0,
+ 'c1': 1,
+ 'c2': 2,
+ 'c3': 3,
+ 'c4': 4,
+ 'c5': 5,
+ 'c6': 6}
+
+ directions = []
+ distances = []
+ for k, v in params.items():
+ directions.append(latent_dirs[param_indexes[k]])
+ distances.append(v)
+
+ w1 = model.sample_latent(1, seed=seed1).detach().cpu().numpy()
+ w1 = [w1]*model.get_max_latents() # one per layer
+ im1 = model.sample_np(w1)
+
+ w2 = model.sample_latent(1, seed=seed2).detach().cpu().numpy()
+ w2 = [w2]*model.get_max_latents() # one per layer
+ im2 = model.sample_np(w2)
+ combined_im = np.concatenate([im1, im2], axis=1)
+ input_im = Image.fromarray((combined_im * 255).astype(np.uint8))
+
+
+ mixed_w = mix_w(w1, w2, content, style)
+ return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False)
+
+truncation = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label="Truncation")
+start_layer = gr.inputs.Number(default=3, label="Start Layer")
+end_layer = gr.inputs.Number(default=14, label="End Layer")
+seed1 = gr.inputs.Number(default=0, label="Seed 1")
+seed2 = gr.inputs.Number(default=0, label="Seed 2")
+content = gr.inputs.Slider(label="Structure", minimum=0, maximum=1, default=0.5)
+style = gr.inputs.Slider(label="Style", minimum=0, maximum=1, default=0.5)
+
+slider_max_val = 20
+slider_min_val = -20
+slider_step = 1
+
+c0 = gr.inputs.Slider(label="Sleeve & Size", minimum=slider_min_val, maximum=slider_max_val, default=0)
+c1 = gr.inputs.Slider(label="Dress - Jacket", minimum=slider_min_val, maximum=slider_max_val, default=0)
+c2 = gr.inputs.Slider(label="Female Coat", minimum=slider_min_val, maximum=slider_max_val, default=0)
+c3 = gr.inputs.Slider(label="Coat", minimum=slider_min_val, maximum=slider_max_val, default=0)
+c4 = gr.inputs.Slider(label="Graphics", minimum=slider_min_val, maximum=slider_max_val, default=0)
+c5 = gr.inputs.Slider(label="Dark", minimum=slider_min_val, maximum=slider_max_val, default=0)
+c6 = gr.inputs.Slider(label="Less Cleavage", minimum=slider_min_val, maximum=slider_max_val, default=0)
+
+
+scale = 1
+
+inputs = [seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer]
+description = "Change the seed number to generate different parent designs. Please give a clap/star if you find it useful :)"
+
+article = ""  # the credits/links HTML for the article is not included in this diff
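+
+# The original launch code is not shown in this diff; below is a minimal
+# sketch (an assumption, not necessarily the author's exact code) of wiring
+# the inputs above into the Gradio 2.x Interface API:
+gr.Interface(
+    fn=generate_image,
+    inputs=inputs,
+    outputs=["image", "image"],
+    description=description,
+    article=article,
+).launch()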
diff --git a/netdissect/dissection.py b/netdissect/dissection.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eef0dfd0b8804e45eb878aca68e72f8c6493474
--- /dev/null
+++ b/netdissect/dissection.py
@@ -0,0 +1,1617 @@
+'''
+To run dissection:
+
+1. Load up the convolutional model you wish to dissect, and wrap it in
+ an InstrumentedModel; then call imodel.retain_layers([layernames,..])
+ to instrument the layers of interest.
+2. Load the segmentation dataset using the BrodenDataset class;
+ use the transform_image argument to normalize images to be
+ suitable for the model, or the size argument to truncate the dataset.
+3. Choose a directory in which to write the output, and call
+ dissect(outdir, model, dataset).
+
+Example:
+
+ from dissect import InstrumentedModel, dissect
+ from broden import BrodenDataset
+
+ model = InstrumentedModel(load_my_model())
+ model.eval()
+ model.cuda()
+ model.retain_layers(['conv1', 'conv2', 'conv3', 'conv4', 'conv5'])
+ bds = BrodenDataset('dataset/broden1_227',
+ transform_image=transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(IMAGE_MEAN, IMAGE_STDEV)]),
+ size=1000)
+ dissect('result/dissect', model, bds,
+ examples_per_unit=10)
+'''
+
+import torch, numpy, os, re, json, shutil, types, tempfile, torchvision
+# import warnings
+# warnings.simplefilter('error', UserWarning)
+from PIL import Image
+from xml.etree import ElementTree as et
+from collections import OrderedDict, defaultdict
+from .progress import verbose_progress, default_progress, print_progress
+from .progress import desc_progress
+from .runningstats import RunningQuantile, RunningTopK
+from .runningstats import RunningCrossCovariance, RunningConditionalQuantile
+from .sampler import FixedSubsetSampler
+from .actviz import activation_visualization
+from .segviz import segment_visualization, high_contrast
+from .workerpool import WorkerBase, WorkerPool
+from .segmenter import UnifiedParsingSegmenter
+
+def dissect(outdir, model, dataset,
+ segrunner=None,
+ train_dataset=None,
+ model_segmenter=None,
+ quantile_threshold=0.005,
+ iou_threshold=0.05,
+ iqr_threshold=0.01,
+ examples_per_unit=100,
+ batch_size=100,
+ num_workers=24,
+ seg_batch_size=5,
+ make_images=True,
+ make_labels=True,
+ make_maxiou=False,
+ make_covariance=False,
+ make_report=True,
+ make_row_images=True,
+ make_single_images=False,
+ rank_all_labels=False,
+ netname=None,
+ meta=None,
+ merge=None,
+ settings=None,
+ ):
+ '''
+ Runs net dissection in-memory, using pytorch, and saves visualizations
+ and metadata into outdir.
+ '''
+ assert not model.training, 'Run model.eval() before dissection'
+ if netname is None:
+ netname = type(model).__name__
+ if segrunner is None:
+ segrunner = ClassifierSegRunner(dataset)
+ if train_dataset is None:
+ train_dataset = dataset
+ make_iqr = (quantile_threshold == 'iqr')
+ with torch.no_grad():
+ device = next(model.parameters()).device
+ levels = None
+ labelnames, catnames = None, None
+ maxioudata, iqrdata = None, None
+ labeldata = None
+ iqrdata, cov = None, None
+
+ labelnames, catnames = segrunner.get_label_and_category_names()
+ label_category = [catnames.index(c) if c in catnames else 0
+ for l, c in labelnames]
+
+        # First, always collect quantiles and topk information.
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=batch_size, num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'))
+ quantiles, topk = collect_quantiles_and_topk(outdir, model,
+ segloader, segrunner, k=examples_per_unit)
+
+ # Thresholds can be automatically chosen by maximizing iqr
+ if make_iqr:
+ # Get thresholds based on an IQR optimization
+ segloader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=1, num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'))
+ iqrdata = collect_iqr(outdir, model, segloader, segrunner)
+ max_iqr, full_iqr_levels = iqrdata[:2]
+ max_iqr_agreement = iqrdata[4]
+ # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
+ levels = {layer: full_iqr_levels[layer][
+ max_iqr[layer].max(0)[1],
+ torch.arange(max_iqr[layer].shape[1])].to(device)
+ for layer in full_iqr_levels}
+ else:
+ levels = {k: qc.quantiles([1.0 - quantile_threshold])[:,0]
+ for k, qc in quantiles.items()}
+
+ quantiledata = (topk, quantiles, levels, quantile_threshold)
+
+ if make_images:
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=batch_size, num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'))
+ generate_images(outdir, model, dataset, topk, levels, segrunner,
+ row_length=examples_per_unit, batch_size=seg_batch_size,
+ row_images=make_row_images,
+ single_images=make_single_images,
+ num_workers=num_workers)
+
+ if make_maxiou:
+ assert train_dataset, "Need training dataset for maxiou."
+ segloader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=1, num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'))
+ maxioudata = collect_maxiou(outdir, model, segloader,
+ segrunner)
+
+ if make_labels:
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=1, num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'))
+ iou_scores, iqr_scores, tcs, lcs, ccs, ics = (
+ collect_bincounts(outdir, model, segloader,
+ levels, segrunner))
+ labeldata = (iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold,
+ iqr_threshold)
+
+ if make_covariance:
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=seg_batch_size,
+ num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'))
+ cov = collect_covariance(outdir, model, segloader, segrunner)
+
+ if make_report:
+ generate_report(outdir,
+ quantiledata=quantiledata,
+ labelnames=labelnames,
+ catnames=catnames,
+ labeldata=labeldata,
+ maxioudata=maxioudata,
+ iqrdata=iqrdata,
+ covariancedata=cov,
+ rank_all_labels=rank_all_labels,
+ netname=netname,
+ meta=meta,
+ mergedata=merge,
+ settings=settings)
+
+ return quantiledata, labeldata
+
+def generate_report(outdir, quantiledata, labelnames=None, catnames=None,
+ labeldata=None, maxioudata=None, iqrdata=None, covariancedata=None,
+ rank_all_labels=False, netname='Model', meta=None, settings=None,
+ mergedata=None):
+ '''
+ Creates dissection.json reports and summary bargraph.svg files in the
+ specified output directory, and copies a dissection.html interface
+ to go along with it.
+ '''
+ all_layers = []
+ # Current source code directory, for html to copy.
+ srcdir = os.path.realpath(
+ os.path.join(os.getcwd(), os.path.dirname(__file__)))
+ # Unpack arguments
+ topk, quantiles, levels, quantile_threshold = quantiledata
+ top_record = dict(
+ netname=netname,
+ meta=meta,
+ default_ranking='unit',
+ quantile_threshold=quantile_threshold)
+ if settings is not None:
+ top_record['settings'] = settings
+ if labeldata is not None:
+ iou_scores, iqr_scores, lcs, ccs, ics, iou_threshold, iqr_threshold = (
+ labeldata)
+ catorder = {'object': -7, 'scene': -6, 'part': -5,
+ 'piece': -4,
+ 'material': -3, 'texture': -2, 'color': -1}
+ for i, cat in enumerate(c for c in catnames if c not in catorder):
+ catorder[cat] = i
+ catnumber = {n: i for i, n in enumerate(catnames)}
+ catnumber['-'] = 0
+ top_record['default_ranking'] = 'label'
+ top_record['iou_threshold'] = iou_threshold
+ top_record['iqr_threshold'] = iqr_threshold
+ labelnumber = dict((name[0], num)
+ for num, name in enumerate(labelnames))
+ # Make a segmentation color dictionary
+ segcolors = {}
+ for i, name in enumerate(labelnames):
+ key = ','.join(str(s) for s in high_contrast[i % len(high_contrast)])
+ if key in segcolors:
+ segcolors[key] += '/' + name[0]
+ else:
+ segcolors[key] = name[0]
+ top_record['segcolors'] = segcolors
+ for layer in topk.keys():
+ units, rankings = [], []
+ record = dict(layer=layer, units=units, rankings=rankings)
+ # For every unit, we always have basic visualization information.
+ topa, topi = topk[layer].result()
+ lev = levels[layer]
+ for u in range(len(topa)):
+ units.append(dict(
+ unit=u,
+ interp=True,
+ level=lev[u].item(),
+ top=[dict(imgnum=i.item(), maxact=a.item())
+ for i, a in zip(topi[u], topa[u])],
+ ))
+ rankings.append(dict(name="unit", score=list([
+ u for u in range(len(topa))])))
+ # TODO: consider including stats and ranking based on quantiles,
+ # variance, connectedness here.
+
+ # if we have labeldata, then every unit also gets a bunch of other info
+ if labeldata is not None:
+ lscore, qscore, cc, ic = [dat[layer]
+ for dat in [iou_scores, iqr_scores, ccs, ics]]
+ if iqrdata is not None:
+ # If we have IQR thresholds, assign labels based on that
+ max_iqr, max_iqr_level = iqrdata[:2]
+ best_label = max_iqr[layer].max(0)[1]
+ best_score = lscore[best_label, torch.arange(lscore.shape[1])]
+ best_qscore = qscore[best_label, torch.arange(lscore.shape[1])]
+ else:
+ # Otherwise, assign labels based on max iou
+ best_score, best_label = lscore.max(0)
+ best_qscore = qscore[best_label, torch.arange(qscore.shape[1])]
+            record['iou_threshold'] = iou_threshold
+ for u, urec in enumerate(units):
+ score, qscore, label = (
+ best_score[u], best_qscore[u], best_label[u])
+ urec.update(dict(
+ iou=score.item(),
+ iou_iqr=qscore.item(),
+ lc=lcs[label].item(),
+ cc=cc[catnumber[labelnames[label][1]], u].item(),
+ ic=ic[label, u].item(),
+ interp=(qscore.item() > iqr_threshold and
+ score.item() > iou_threshold),
+ iou_labelnum=label.item(),
+ iou_label=labelnames[label.item()][0],
+ iou_cat=labelnames[label.item()][1],
+ ))
+ if maxioudata is not None:
+ max_iou, max_iou_level, max_iou_quantile = maxioudata
+ qualified_iou = max_iou[layer].clone()
+ # qualified_iou[max_iou_quantile[layer] > 0.75] = 0
+ best_score, best_label = qualified_iou.max(0)
+ for u, urec in enumerate(units):
+ urec.update(dict(
+ maxiou=best_score[u].item(),
+ maxiou_label=labelnames[best_label[u].item()][0],
+ maxiou_cat=labelnames[best_label[u].item()][1],
+ maxiou_level=max_iou_level[layer][best_label[u], u].item(),
+ maxiou_quantile=max_iou_quantile[layer][
+ best_label[u], u].item()))
+ if iqrdata is not None:
+ [max_iqr, max_iqr_level, max_iqr_quantile,
+ max_iqr_iou, max_iqr_agreement] = iqrdata
+ qualified_iqr = max_iqr[layer].clone()
+ qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
+ best_score, best_label = qualified_iqr.max(0)
+ for u, urec in enumerate(units):
+ urec.update(dict(
+ iqr=best_score[u].item(),
+ iqr_label=labelnames[best_label[u].item()][0],
+ iqr_cat=labelnames[best_label[u].item()][1],
+ iqr_level=max_iqr_level[layer][best_label[u], u].item(),
+ iqr_quantile=max_iqr_quantile[layer][
+ best_label[u], u].item(),
+ iqr_iou=max_iqr_iou[layer][best_label[u], u].item()
+ ))
+ if covariancedata is not None:
+ score = covariancedata[layer].correlation()
+ best_score, best_label = score.max(1)
+ for u, urec in enumerate(units):
+ urec.update(dict(
+ cor=best_score[u].item(),
+ cor_label=labelnames[best_label[u].item()][0],
+ cor_cat=labelnames[best_label[u].item()][1]
+ ))
+ if mergedata is not None:
+ # Final step: if the user passed any data to merge into the
+ # units, merge them now. This can be used, for example, to
+            # indicate that a unit is not interpretable based on some
+ # outside analysis of unit statistics.
+ for lrec in mergedata.get('layers', []):
+ if lrec['layer'] == layer:
+ break
+ else:
+ lrec = None
+ for u, urec in enumerate(lrec.get('units', []) if lrec else []):
+ units[u].update(urec)
+ # After populating per-unit info, populate per-layer ranking info
+ if labeldata is not None:
+ # Collect all labeled units
+ labelunits = defaultdict(list)
+ all_labelunits = defaultdict(list)
+ for u, urec in enumerate(units):
+ if urec['interp']:
+ labelunits[urec['iou_labelnum']].append(u)
+ all_labelunits[urec['iou_labelnum']].append(u)
+ # Sort all units in order with most popular label first.
+ label_ordering = sorted(units,
+ # Sort by:
+ key=lambda r: (-1 if r['interp'] else 0, # interpretable
+ -len(labelunits[r['iou_labelnum']]), # label freq, score
+ -max([units[u]['iou']
+ for u in labelunits[r['iou_labelnum']]], default=0),
+ r['iou_labelnum'], # label
+ -r['iou'])) # unit score
+ # Add label and iou ranking.
+ rankings.append(dict(name="label", score=(numpy.argsort(list(
+ ur['unit'] for ur in label_ordering))).tolist()))
+ rankings.append(dict(name="max iou", metric="iou", score=list(
+ -ur['iou'] for ur in units)))
+ # Add ranking for top labels
+ # for labelnum in [n for n in sorted(
+ # all_labelunits.keys(), key=lambda x:
+ # -len(all_labelunits[x])) if len(all_labelunits[n])]:
+ # label = labelnames[labelnum][0]
+ # rankings.append(dict(name="%s-iou" % label,
+ # concept=label, metric='iou',
+ # score=(-lscore[labelnum, :]).tolist()))
+ # Collate labels by category then frequency.
+ record['labels'] = [dict(
+ label=labelnames[label][0],
+ labelnum=label,
+ units=labelunits[label],
+ cat=labelnames[label][1])
+ for label in (sorted(labelunits.keys(),
+ # Sort by:
+ key=lambda l: (catorder.get( # category
+ labelnames[l][1], 0),
+ -len(labelunits[l]), # label freq
+ -max([units[u]['iou'] for u in labelunits[l]],
+ default=0) # score
+ ))) if len(labelunits[label])]
+ # Total number of interpretable units.
+ record['interpretable'] = sum(len(group['units'])
+ for group in record['labels'])
+ # Make a bargraph of labels
+ os.makedirs(os.path.join(outdir, safe_dir_name(layer)),
+ exist_ok=True)
+ catgroups = OrderedDict()
+ for _, cat in sorted([(v, k) for k, v in catorder.items()]):
+ catgroups[cat] = []
+ for rec in record['labels']:
+ if rec['cat'] not in catgroups:
+ catgroups[rec['cat']] = []
+ catgroups[rec['cat']].append(rec['label'])
+ make_svg_bargraph(
+ [rec['label'] for rec in record['labels']],
+ [len(rec['units']) for rec in record['labels']],
+ [(cat, len(group)) for cat, group in catgroups.items()],
+ filename=os.path.join(outdir, safe_dir_name(layer),
+ 'bargraph.svg'))
+ # Only show the bargraph if it is non-empty.
+ if len(record['labels']):
+ record['bargraph'] = 'bargraph.svg'
+ if maxioudata is not None:
+ rankings.append(dict(name="max maxiou", metric="maxiou", score=list(
+ -ur['maxiou'] for ur in units)))
+ if iqrdata is not None:
+ rankings.append(dict(name="max iqr", metric="iqr", score=list(
+ -ur['iqr'] for ur in units)))
+ if covariancedata is not None:
+ rankings.append(dict(name="max cor", metric="cor", score=list(
+ -ur['cor'] for ur in units)))
+
+ all_layers.append(record)
+ # Now add the same rankings to every layer...
+ all_labels = None
+ if rank_all_labels:
+ all_labels = [name for name, cat in labelnames]
+ if labeldata is not None:
+ # Count layers+quadrants with a given label, and sort by freq
+ counted_labels = defaultdict(int)
+ for label in [
+ re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '', unitrec['iou_label'])
+ for record in all_layers for unitrec in record['units']]:
+ counted_labels[label] += 1
+ if all_labels is None:
+ all_labels = [label for count, label in sorted((-v, k)
+ for k, v in counted_labels.items())]
+ for record in all_layers:
+ layer = record['layer']
+ for label in all_labels:
+ labelnum = labelnumber[label]
+ record['rankings'].append(dict(name="%s-iou" % label,
+ concept=label, metric='iou',
+ score=(-iou_scores[layer][labelnum, :]).tolist()))
+
+ if maxioudata is not None:
+ if all_labels is None:
+ counted_labels = defaultdict(int)
+ for label in [
+ re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
+ unitrec['maxiou_label'])
+ for record in all_layers for unitrec in record['units']]:
+ counted_labels[label] += 1
+ all_labels = [label for count, label in sorted((-v, k)
+ for k, v in counted_labels.items())]
+        for record in all_layers:
+            layer = record['layer']
+            qualified_iou = max_iou[layer].clone()
+            qualified_iou[max_iou_quantile[layer] > 0.5] = 0
+            for label in all_labels:
+ labelnum = labelnumber[label]
+ record['rankings'].append(dict(name="%s-maxiou" % label,
+ concept=label, metric='maxiou',
+ score=(-qualified_iou[labelnum, :]).tolist()))
+
+ if iqrdata is not None:
+ if all_labels is None:
+ counted_labels = defaultdict(int)
+ for label in [
+ re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
+ unitrec['iqr_label'])
+ for record in all_layers for unitrec in record['units']]:
+ counted_labels[label] += 1
+ all_labels = [label for count, label in sorted((-v, k)
+ for k, v in counted_labels.items())]
+ # qualified_iqr[max_iqr_quantile[layer] > 0.5] = 0
+ for record in all_layers:
+ layer = record['layer']
+ qualified_iqr = max_iqr[layer].clone()
+ for label in all_labels:
+ labelnum = labelnumber[label]
+ record['rankings'].append(dict(name="%s-iqr" % label,
+ concept=label, metric='iqr',
+ score=(-qualified_iqr[labelnum, :]).tolist()))
+
+ if covariancedata is not None:
+ if all_labels is None:
+ counted_labels = defaultdict(int)
+ for label in [
+ re.sub(r'-(?:t|b|l|r|tl|tr|bl|br)$', '',
+ unitrec['cor_label'])
+ for record in all_layers for unitrec in record['units']]:
+ counted_labels[label] += 1
+ all_labels = [label for count, label in sorted((-v, k)
+ for k, v in counted_labels.items())]
+ for record in all_layers:
+ layer = record['layer']
+ score = covariancedata[layer].correlation()
+ for label in all_labels:
+ labelnum = labelnumber[label]
+ record['rankings'].append(dict(name="%s-cor" % label,
+ concept=label, metric='cor',
+ score=(-score[:, labelnum]).tolist()))
+
+ for record in all_layers:
+ layer = record['layer']
+ # Dump per-layer json inside per-layer directory
+ record['dirname'] = '.'
+ with open(os.path.join(outdir, safe_dir_name(layer), 'dissect.json'),
+ 'w') as jsonfile:
+ top_record['layers'] = [record]
+ json.dump(top_record, jsonfile, indent=1)
+ # Copy the per-layer html
+ shutil.copy(os.path.join(srcdir, 'dissect.html'),
+ os.path.join(outdir, safe_dir_name(layer), 'dissect.html'))
+ record['dirname'] = safe_dir_name(layer)
+
+ # Dump all-layer json in parent directory
+ with open(os.path.join(outdir, 'dissect.json'), 'w') as jsonfile:
+ top_record['layers'] = all_layers
+ json.dump(top_record, jsonfile, indent=1)
+ # Copy the all-layer html
+ shutil.copy(os.path.join(srcdir, 'dissect.html'),
+ os.path.join(outdir, 'dissect.html'))
+ shutil.copy(os.path.join(srcdir, 'edit.html'),
+ os.path.join(outdir, 'edit.html'))
+
+
+def generate_images(outdir, model, dataset, topk, levels,
+ segrunner, row_length=None, gap_pixels=5,
+ row_images=True, single_images=False, prefix='',
+ batch_size=100, num_workers=24):
+ '''
+ Creates an image strip file for every unit of every retained layer
+ of the model, in the format [outdir]/[layername]/[unitnum]-top.jpg.
+ Assumes that the indexes of topk refer to the indexes of dataset.
+ Limits each strip to the top row_length images.
+ '''
+ progress = default_progress()
+ needed_images = {}
+ if row_images is False:
+ row_length = 1
+ # Pass 1: needed_images lists all images that are topk for some unit.
+ for layer in topk:
+ topresult = topk[layer].result()[1].cpu()
+ for unit, row in enumerate(topresult):
+ for rank, imgnum in enumerate(row[:row_length]):
+ imgnum = imgnum.item()
+ if imgnum not in needed_images:
+ needed_images[imgnum] = []
+ needed_images[imgnum].append((layer, unit, rank))
+ levels = {k: v.cpu().numpy() for k, v in levels.items()}
+ row_length = len(row[:row_length])
+ needed_sample = FixedSubsetSampler(sorted(needed_images.keys()))
+ device = next(model.parameters()).device
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=batch_size, num_workers=num_workers,
+ pin_memory=(device.type == 'cuda'),
+ sampler=needed_sample)
+ vizgrid, maskgrid, origrid, seggrid = [{} for _ in range(4)]
+ # Pass 2: populate vizgrid with visualizations of top units.
+ pool = None
+ for i, batch in enumerate(
+ progress(segloader, desc='Making images')):
+ # Reverse transformation to get the image in byte form.
+ seg, _, byte_im, _ = segrunner.run_and_segment_batch(batch, model,
+ want_rgb=True)
+ torch_features = model.retained_features()
+ scale_offset = getattr(model, 'scale_offset', None)
+ if pool is None:
+ # Distribute the work across processes: create shared mmaps.
+ for layer, tf in torch_features.items():
+ [vizgrid[layer], maskgrid[layer], origrid[layer],
+ seggrid[layer]] = [
+ create_temp_mmap_grid((tf.shape[1],
+ byte_im.shape[1], row_length,
+ byte_im.shape[2] + gap_pixels, depth),
+ dtype='uint8',
+ fill=255)
+ for depth in [3, 4, 3, 3]]
+ # Pass those mmaps to worker processes.
+ pool = WorkerPool(worker=VisualizeImageWorker,
+ memmap_grid_info=[
+ {layer: (g.filename, g.shape, g.dtype)
+ for layer, g in grid.items()}
+ for grid in [vizgrid, maskgrid, origrid, seggrid]])
+ byte_im = byte_im.cpu().numpy()
+ numpy_seg = seg.cpu().numpy()
+ features = {}
+ for index in range(len(byte_im)):
+ imgnum = needed_sample.samples[index + i*segloader.batch_size]
+ for layer, unit, rank in needed_images[imgnum]:
+ if layer not in features:
+ features[layer] = torch_features[layer].cpu().numpy()
+ pool.add(layer, unit, rank,
+ byte_im[index],
+ features[layer][index, unit],
+ levels[layer][unit],
+ scale_offset[layer] if scale_offset else None,
+ numpy_seg[index])
+ pool.join()
+ # Pass 3: save image strips as [outdir]/[layer]/[unitnum]-[top/orig].jpg
+ pool = WorkerPool(worker=SaveImageWorker)
+ for layer, vg in progress(vizgrid.items(), desc='Saving images'):
+ os.makedirs(os.path.join(outdir, safe_dir_name(layer),
+ prefix + 'image'), exist_ok=True)
+ if single_images:
+ os.makedirs(os.path.join(outdir, safe_dir_name(layer),
+ prefix + 's-image'), exist_ok=True)
+ og, sg, mg = origrid[layer], seggrid[layer], maskgrid[layer]
+ for unit in progress(range(len(vg)), desc='Units'):
+ for suffix, grid in [('top.jpg', vg), ('orig.jpg', og),
+ ('seg.png', sg), ('mask.png', mg)]:
+ strip = grid[unit].reshape(
+ (grid.shape[1], grid.shape[2] * grid.shape[3],
+ grid.shape[4]))
+ if row_images:
+ filename = os.path.join(outdir, safe_dir_name(layer),
+ prefix + 'image', '%d-%s' % (unit, suffix))
+ pool.add(strip[:,:-gap_pixels,:].copy(), filename)
+ # Image.fromarray(strip[:,:-gap_pixels,:]).save(filename,
+ # optimize=True, quality=80)
+ if single_images:
+ single_filename = os.path.join(outdir, safe_dir_name(layer),
+ prefix + 's-image', '%d-%s' % (unit, suffix))
+ pool.add(strip[:,:strip.shape[1] // row_length
+ - gap_pixels,:].copy(), single_filename)
+ # Image.fromarray(strip[:,:strip.shape[1] // row_length
+ # - gap_pixels,:]).save(single_filename,
+ # optimize=True, quality=80)
+ pool.join()
+ # Delete the shared memory map files
+ clear_global_shared_files([g.filename
+ for grid in [vizgrid, maskgrid, origrid, seggrid]
+ for g in grid.values()])
+
+global_shared_files = {}
+def create_temp_mmap_grid(shape, dtype, fill):
+ dtype = numpy.dtype(dtype)
+ filename = os.path.join(tempfile.mkdtemp(), 'temp-%s-%s.mmap' %
+ ('x'.join('%d' % s for s in shape), dtype.name))
+ fid = open(filename, mode='w+b')
+ original = numpy.memmap(fid, dtype=dtype, mode='w+', shape=shape)
+ original.fid = fid
+ original[...] = fill
+ global_shared_files[filename] = original
+ return original
+
+def shared_temp_mmap_grid(filename, shape, dtype):
+ if filename not in global_shared_files:
+ global_shared_files[filename] = numpy.memmap(
+ filename, dtype=dtype, mode='r+', shape=shape)
+ return global_shared_files[filename]
+
+def clear_global_shared_files(filenames):
+ for fn in filenames:
+ if fn in global_shared_files:
+ del global_shared_files[fn]
+ try:
+ os.unlink(fn)
+ except OSError:
+ pass
+
+class VisualizeImageWorker(WorkerBase):
+ def setup(self, memmap_grid_info):
+ self.vizgrid, self.maskgrid, self.origrid, self.seggrid = [
+ {layer: shared_temp_mmap_grid(*info)
+ for layer, info in grid.items()}
+ for grid in memmap_grid_info]
+ def work(self, layer, unit, rank,
+ byte_im, acts, level, scale_offset, seg):
+ self.origrid[layer][unit,:,rank,:byte_im.shape[0],:] = byte_im
+ [self.vizgrid[layer][unit,:,rank,:byte_im.shape[0],:],
+ self.maskgrid[layer][unit,:,rank,:byte_im.shape[0],:]] = (
+ activation_visualization(
+ byte_im,
+ acts,
+ level,
+ scale_offset=scale_offset,
+ return_mask=True))
+ self.seggrid[layer][unit,:,rank,:byte_im.shape[0],:] = (
+ segment_visualization(seg, byte_im.shape[0:2]))
+
+class SaveImageWorker(WorkerBase):
+ def work(self, data, filename):
+ Image.fromarray(data).save(filename, optimize=True, quality=80)
+
+def score_tally_stats(label_category, tc, truth, cc, ic):
+ pred = cc[label_category]
+ total = tc[label_category][:, None]
+ truth = truth[:, None]
+ epsilon = 1e-20 # avoid division-by-zero
+ union = pred + truth - ic
+ iou = ic.double() / (union.double() + epsilon)
+ arr = torch.empty(size=(2, 2) + ic.shape, dtype=ic.dtype, device=ic.device)
+ arr[0, 0] = ic
+ arr[0, 1] = pred - ic
+ arr[1, 0] = truth - ic
+ arr[1, 1] = total - union
+ arr = arr.double() / total.double()
+ mi = mutual_information(arr)
+ je = joint_entropy(arr)
+ iqr = mi / je
+ iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0
+ return iou, iqr
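+
+# Worked example for score_tally_stats (illustrative numbers): for one
+# unit/label pair with ic = 50 intersection pixels, pred = 100 activated
+# pixels, truth = 80 labeled pixels and total = 1000 category pixels, the
+# union is 100 + 80 - 50 = 130, so iou = 50 / 130 ≈ 0.385; iqr is the mutual
+# information of the 2x2 joint table [[50, 50], [30, 870]] / 1000 divided by
+# its joint entropy.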
+
+def collect_quantiles_and_topk(outdir, model, segloader,
+ segrunner, k=100, resolution=1024):
+ '''
+ Collects (estimated) quantile information and (exact) sorted top-K lists
+ for every channel in the retained layers of the model. Returns
+ a map of quantiles (one RunningQuantile for each layer) along with
+ a map of topk (one RunningTopK for each layer).
+ '''
+ device = next(model.parameters()).device
+ features = model.retained_features()
+ cached_quantiles = {
+ layer: load_quantile_if_present(os.path.join(outdir,
+ safe_dir_name(layer)), 'quantiles.npz',
+ device=torch.device('cpu'))
+ for layer in features }
+ cached_topks = {
+ layer: load_topk_if_present(os.path.join(outdir,
+ safe_dir_name(layer)), 'topk.npz',
+ device=torch.device('cpu'))
+ for layer in features }
+ if (all(value is not None for value in cached_quantiles.values()) and
+ all(value is not None for value in cached_topks.values())):
+ return cached_quantiles, cached_topks
+
+ layer_batch_size = 8
+ all_layers = list(features.keys())
+ layer_batches = [all_layers[i:i+layer_batch_size]
+ for i in range(0, len(all_layers), layer_batch_size)]
+
+ quantiles, topks = {}, {}
+ progress = default_progress()
+ for layer_batch in layer_batches:
+ for i, batch in enumerate(progress(segloader, desc='Quantiles')):
+ # We don't actually care about the model output.
+ model(batch[0].to(device))
+ features = model.retained_features()
+ # We care about the retained values
+ for key in layer_batch:
+ value = features[key]
+ if topks.get(key, None) is None:
+ topks[key] = RunningTopK(k)
+ if quantiles.get(key, None) is None:
+ quantiles[key] = RunningQuantile(resolution=resolution)
+ topvalue = value
+ if len(value.shape) > 2:
+ topvalue, _ = value.view(*(value.shape[:2] + (-1,))).max(2)
+ # Put the channel index last.
+ value = value.permute(
+ (0,) + tuple(range(2, len(value.shape))) + (1,)
+ ).contiguous().view(-1, value.shape[1])
+ quantiles[key].add(value)
+ topks[key].add(topvalue)
+ # Save GPU memory
+ for key in layer_batch:
+ quantiles[key].to_(torch.device('cpu'))
+ topks[key].to_(torch.device('cpu'))
+ for layer in quantiles:
+ save_state_dict(quantiles[layer],
+ os.path.join(outdir, safe_dir_name(layer), 'quantiles.npz'))
+ save_state_dict(topks[layer],
+ os.path.join(outdir, safe_dir_name(layer), 'topk.npz'))
+ return quantiles, topks
+
+def collect_bincounts(outdir, model, segloader, levels, segrunner):
+ '''
+ Returns label_counts, category_activation_counts, and intersection_counts,
+ across the data set, counting the pixels of intersection between upsampled,
+ thresholded model featuremaps, with segmentation classes in the segloader.
+
+ label_counts (independent of model): pixels across the data set that
+ are labeled with the given label.
+ category_activation_counts (one per layer): for each feature channel,
+ pixels across the dataset where the channel exceeds the level
+ threshold. There is one count per category: activations only
+ contribute to the categories for which any category labels are
+ present on the images.
+ intersection_counts (one per layer): for each feature channel and
+ label, pixels across the dataset where the channel exceeds
+ the level, and the labeled segmentation class is also present.
+
+ This is a performance-sensitive function. Best performance is
+ achieved with a counting scheme which assumes a segloader with
+ batch_size 1.
+ '''
+ # Load cached data if present
+ (iou_scores, iqr_scores,
+ total_counts, label_counts, category_activation_counts,
+ intersection_counts) = {}, {}, None, None, {}, {}
+ found_all = True
+ for layer in model.retained_features():
+ filename = os.path.join(outdir, safe_dir_name(layer), 'bincounts.npz')
+ if os.path.isfile(filename):
+ data = numpy.load(filename)
+ iou_scores[layer] = torch.from_numpy(data['iou_scores'])
+ iqr_scores[layer] = torch.from_numpy(data['iqr_scores'])
+ total_counts = torch.from_numpy(data['total_counts'])
+ label_counts = torch.from_numpy(data['label_counts'])
+ category_activation_counts[layer] = torch.from_numpy(
+ data['category_activation_counts'])
+ intersection_counts[layer] = torch.from_numpy(
+ data['intersection_counts'])
+ else:
+ found_all = False
+ if found_all:
+ return (iou_scores, iqr_scores,
+ total_counts, label_counts, category_activation_counts,
+ intersection_counts)
+
+ device = next(model.parameters()).device
+ labelcat, categories = segrunner.get_label_and_category_names()
+ label_category = [categories.index(c) if c in categories else 0
+ for l, c in labelcat]
+ num_labels, num_categories = (len(n) for n in [labelcat, categories])
+
+ # One-hot vector of category for each label
+ labelcat = torch.zeros(num_labels, num_categories,
+ dtype=torch.long, device=device)
+ labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
+ dtype='int64')).to(device)[:,None], 1)
+ # Running bincounts
+ # activation_counts = {}
+ assert segloader.batch_size == 1 # category_activation_counts needs this.
+ category_activation_counts = {}
+ intersection_counts = {}
+ label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
+ total_counts = torch.zeros(num_categories, dtype=torch.long, device=device)
+ progress = default_progress()
+ scale_offset_map = getattr(model, 'scale_offset', None)
+ upsample_grids = {}
+ # total_batch_categories = torch.zeros(
+ # labelcat.shape[1], dtype=torch.long, device=device)
+ for i, batch in enumerate(progress(segloader, desc='Bincounts')):
+ seg, batch_label_counts, _, imshape = segrunner.run_and_segment_batch(
+ batch, model, want_bincount=True, want_rgb=True)
+ bc = batch_label_counts.cpu()
+ batch_label_counts = batch_label_counts.to(device)
+ seg = seg.to(device)
+ features = model.retained_features()
+ # Accumulate bincounts and identify nonzeros
+ label_counts += batch_label_counts[0]
+ batch_labels = bc[0].nonzero()[:,0]
+ batch_categories = labelcat[batch_labels].max(0)[0]
+ total_counts += batch_categories * (
+ seg.shape[0] * seg.shape[2] * seg.shape[3])
+ for key, value in features.items():
+ if key not in upsample_grids:
+ upsample_grids[key] = upsample_grid(value.shape[2:],
+ seg.shape[2:], imshape,
+ scale_offset=scale_offset_map.get(key, None)
+ if scale_offset_map is not None else None,
+ dtype=value.dtype, device=value.device)
+ upsampled = torch.nn.functional.grid_sample(value,
+ upsample_grids[key], padding_mode='border')
+ amask = (upsampled > levels[key][None,:,None,None].to(
+ upsampled.device))
+ ac = amask.int().view(amask.shape[1], -1).sum(1)
+ # if key not in activation_counts:
+ # activation_counts[key] = ac
+ # else:
+ # activation_counts[key] += ac
+ # The fastest approach: sum over each label separately!
+ for label in batch_labels.tolist():
+ if label == 0:
+ continue # ignore the background label
+ imask = amask * ((seg == label).max(dim=1, keepdim=True)[0])
+ ic = imask.int().view(imask.shape[1], -1).sum(1)
+ if key not in intersection_counts:
+ intersection_counts[key] = torch.zeros(num_labels,
+ amask.shape[1], dtype=torch.long, device=device)
+ intersection_counts[key][label] += ic
+ # Count activations within images that have category labels.
+ # Note: This only makes sense with batch-size one
+ # total_batch_categories += batch_categories
+ cc = batch_categories[:,None] * ac[None,:]
+ if key not in category_activation_counts:
+ category_activation_counts[key] = cc
+ else:
+ category_activation_counts[key] += cc
+ iou_scores = {}
+ iqr_scores = {}
+ for k in intersection_counts:
+ iou_scores[k], iqr_scores[k] = score_tally_stats(
+ label_category, total_counts, label_counts,
+ category_activation_counts[k], intersection_counts[k])
+ for k in intersection_counts:
+ numpy.savez(os.path.join(outdir, safe_dir_name(k), 'bincounts.npz'),
+ iou_scores=iou_scores[k].cpu().numpy(),
+ iqr_scores=iqr_scores[k].cpu().numpy(),
+ total_counts=total_counts.cpu().numpy(),
+ label_counts=label_counts.cpu().numpy(),
+ category_activation_counts=category_activation_counts[k]
+ .cpu().numpy(),
+ intersection_counts=intersection_counts[k].cpu().numpy(),
+ levels=levels[k].cpu().numpy())
+ return (iou_scores, iqr_scores,
+ total_counts, label_counts, category_activation_counts,
+ intersection_counts)
+
+def collect_cond_quantiles(outdir, model, segloader, segrunner):
+ '''
+ Returns maxiou and maxiou_level across the data set, one per layer.
+
+ This is a performance-sensitive function. Best performance is
+ achieved with a counting scheme which assumes a segloader with
+ batch_size 1.
+ '''
+ device = next(model.parameters()).device
+ cached_cond_quantiles = {
+ layer: load_conditional_quantile_if_present(os.path.join(outdir,
+ safe_dir_name(layer)), 'cond_quantiles.npz') # on cpu
+ for layer in model.retained_features() }
+ label_fracs = load_npy_if_present(outdir, 'label_fracs.npy', 'cpu')
+ if label_fracs is not None and all(
+ value is not None for value in cached_cond_quantiles.values()):
+ return cached_cond_quantiles, label_fracs
+
+ labelcat, categories = segrunner.get_label_and_category_names()
+ label_category = [categories.index(c) if c in categories else 0
+ for l, c in labelcat]
+ num_labels, num_categories = (len(n) for n in [labelcat, categories])
+
+ # One-hot vector of category for each label
+ labelcat = torch.zeros(num_labels, num_categories,
+ dtype=torch.long, device=device)
+ labelcat.scatter_(1, torch.from_numpy(numpy.array(label_category,
+ dtype='int64')).to(device)[:,None], 1)
+ # Running maxiou
+ assert segloader.batch_size == 1 # category_activation_counts needs this.
+ conditional_quantiles = {}
+ label_counts = torch.zeros(num_labels, dtype=torch.long, device=device)
+ pixel_count = 0
+ progress = default_progress()
+ scale_offset_map = getattr(model, 'scale_offset', None)
+ upsample_grids = {}
+ common_conditions = set()
+    if label_fracs is None:
+ for i, batch in enumerate(progress(segloader, desc='label fracs')):
+ seg, batch_label_counts, im, _ = segrunner.run_and_segment_batch(
+ batch, model, want_bincount=True, want_rgb=True)
+ batch_label_counts = batch_label_counts.to(device)
+ features = model.retained_features()
+ # Accumulate bincounts and identify nonzeros
+ label_counts += batch_label_counts[0]
+ pixel_count += seg.shape[2] * seg.shape[3]
+ label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]
+ numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)
+
+ skip_threshold = 1e-4
+ skip_labels = set(i.item()
+ for i in (label_fracs.view(-1) < skip_threshold).nonzero().view(-1))
+
+ for layer in progress(model.retained_features().keys(), desc='CQ layers'):
+ if cached_cond_quantiles.get(layer, None) is not None:
+ conditional_quantiles[layer] = cached_cond_quantiles[layer]
+ continue
+
+ for i, batch in enumerate(progress(segloader, desc='Condquant')):
+ seg, batch_label_counts, _, imshape = (
+ segrunner.run_and_segment_batch(
+ batch, model, want_bincount=True, want_rgb=True))
+ bc = batch_label_counts.cpu()
+ batch_label_counts = batch_label_counts.to(device)
+ features = model.retained_features()
+ # Accumulate bincounts and identify nonzeros
+ label_counts += batch_label_counts[0]
+ pixel_count += seg.shape[2] * seg.shape[3]
+ batch_labels = bc[0].nonzero()[:,0]
+ batch_categories = labelcat[batch_labels].max(0)[0]
+ cpu_seg = None
+ value = features[layer]
+ if layer not in upsample_grids:
+ upsample_grids[layer] = upsample_grid(value.shape[2:],
+ seg.shape[2:], imshape,
+ scale_offset=scale_offset_map.get(layer, None)
+ if scale_offset_map is not None else None,
+ dtype=value.dtype, device=value.device)
+ if layer not in conditional_quantiles:
+ conditional_quantiles[layer] = RunningConditionalQuantile(
+ resolution=2048)
+ upsampled = torch.nn.functional.grid_sample(value,
+ upsample_grids[layer], padding_mode='border').view(
+ value.shape[1], -1)
+ conditional_quantiles[layer].add(('all',), upsampled.t())
+ cpu_upsampled = None
+ for label in batch_labels.tolist():
+ if label in skip_labels:
+ continue
+ label_key = ('label', label)
+ if label_key in common_conditions:
+ imask = (seg == label).max(dim=1)[0].view(-1)
+ intersected = upsampled[:, imask]
+ conditional_quantiles[layer].add(('label', label),
+ intersected.t())
+ else:
+ if cpu_seg is None:
+ cpu_seg = seg.cpu()
+ if cpu_upsampled is None:
+ cpu_upsampled = upsampled.cpu()
+ imask = (cpu_seg == label).max(dim=1)[0].view(-1)
+ intersected = cpu_upsampled[:, imask]
+ conditional_quantiles[layer].add(('label', label),
+ intersected.t())
+ if num_categories > 1:
+ for cat in batch_categories.nonzero()[:,0]:
+ conditional_quantiles[layer].add(('cat', cat.item()),
+ upsampled.t())
+ # Move the most common conditions to the GPU.
+ if i and not i & (i - 1): # if i is a power of 2:
+ cq = conditional_quantiles[layer]
+ common_conditions = set(cq.most_common_conditions(64))
+ cq.to_('cpu', [k for k in cq.running_quantiles.keys()
+ if k not in common_conditions])
+ # When a layer is done, get it off the GPU
+ conditional_quantiles[layer].to_('cpu')
+
+ label_fracs = (label_counts.cpu().float() / pixel_count)[:, None, None]
+
+ for cq in conditional_quantiles.values():
+ cq.to_('cpu')
+
+ for layer in conditional_quantiles:
+ save_state_dict(conditional_quantiles[layer],
+ os.path.join(outdir, safe_dir_name(layer), 'cond_quantiles.npz'))
+ numpy.save(os.path.join(outdir, 'label_fracs.npy'), label_fracs)
+
+ return conditional_quantiles, label_fracs
+
+
+def collect_maxiou(outdir, model, segloader, segrunner):
+ '''
+ Returns maxiou and maxiou_level across the data set, one per layer.
+
+ This is a performance-sensitive function. Best performance is
+ achieved with a counting scheme which assumes a segloader with
+ batch_size 1.
+ '''
+ device = next(model.parameters()).device
+ conditional_quantiles, label_fracs = collect_cond_quantiles(
+ outdir, model, segloader, segrunner)
+
+ labelcat, categories = segrunner.get_label_and_category_names()
+ label_category = [categories.index(c) if c in categories else 0
+ for l, c in labelcat]
+ num_labels, num_categories = (len(n) for n in [labelcat, categories])
+
+ label_list = [('label', i) for i in range(num_labels)]
+ category_list = [('all',)] if num_categories <= 1 else (
+ [('cat', i) for i in range(num_categories)])
+ max_iou, max_iou_level, max_iou_quantile = {}, {}, {}
+ fracs = torch.logspace(-3, 0, 100)
+ progress = default_progress()
+ for layer, cq in progress(conditional_quantiles.items(), desc='Maxiou'):
+ levels = cq.conditional(('all',)).quantiles(1 - fracs)
+ denoms = 1 - cq.collected_normalize(category_list, levels)
+ isects = (1 - cq.collected_normalize(label_list, levels)) * label_fracs
+ unions = label_fracs + denoms[label_category, :, :] - isects
+ iou = isects / unions
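+        # For intuition (numbers are illustrative): if a label covers 10% of
+        # pixels (label_frac 0.1), and a unit exceeds the threshold on 5% of
+        # pixels overall but on 40% of the label's pixels, then
+        # isect = 0.4 * 0.1 = 0.04, union = 0.1 + 0.05 - 0.04 = 0.11, and
+        # iou = 0.04 / 0.11, roughly 0.36.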
+ # TODO: erase any for which threshold is bad
+ max_iou[layer], level_bucket = iou.max(2)
+ max_iou_level[layer] = levels[
+ torch.arange(levels.shape[0])[None,:], level_bucket]
+ max_iou_quantile[layer] = fracs[level_bucket]
+ for layer in model.retained_features():
+ numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'max_iou.npz'),
+ max_iou=max_iou[layer].cpu().numpy(),
+ max_iou_level=max_iou_level[layer].cpu().numpy(),
+ max_iou_quantile=max_iou_quantile[layer].cpu().numpy())
+ return (max_iou, max_iou_level, max_iou_quantile)
+
+def collect_iqr(outdir, model, segloader, segrunner):
+ '''
+ Returns iqr and iqr_level.
+
+ This is a performance-sensitive function. Best performance is
+ achieved with a counting scheme which assumes a segloader with
+ batch_size 1.
+ '''
+ max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou = {}, {}, {}, {}
+ max_iqr_agreement = {}
+ found_all = True
+ for layer in model.retained_features():
+ filename = os.path.join(outdir, safe_dir_name(layer), 'iqr.npz')
+ if os.path.isfile(filename):
+ data = numpy.load(filename)
+ max_iqr[layer] = torch.from_numpy(data['max_iqr'])
+ max_iqr_level[layer] = torch.from_numpy(data['max_iqr_level'])
+ max_iqr_quantile[layer] = torch.from_numpy(data['max_iqr_quantile'])
+ max_iqr_iou[layer] = torch.from_numpy(data['max_iqr_iou'])
+ max_iqr_agreement[layer] = torch.from_numpy(
+ data['max_iqr_agreement'])
+ else:
+ found_all = False
+ if found_all:
+ return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
+ max_iqr_agreement)
+
+
+ device = next(model.parameters()).device
+ conditional_quantiles, label_fracs = collect_cond_quantiles(
+ outdir, model, segloader, segrunner)
+
+ labelcat, categories = segrunner.get_label_and_category_names()
+ label_category = [categories.index(c) if c in categories else 0
+ for l, c in labelcat]
+ num_labels, num_categories = (len(n) for n in [labelcat, categories])
+
+ label_list = [('label', i) for i in range(num_labels)]
+ category_list = [('all',)] if num_categories <= 1 else (
+ [('cat', i) for i in range(num_categories)])
+ full_mi, full_je, full_iqr = {}, {}, {}
+ fracs = torch.logspace(-3, 0, 100)
+ progress = default_progress()
+ for layer, cq in progress(conditional_quantiles.items(), desc='IQR'):
+ levels = cq.conditional(('all',)).quantiles(1 - fracs)
+ truth = label_fracs.to(device)
+ preds = (1 - cq.collected_normalize(category_list, levels)
+ )[label_category, :, :].to(device)
+ cond_isects = 1 - cq.collected_normalize(label_list, levels).to(device)
+ isects = cond_isects * truth
+ unions = truth + preds - isects
+ arr = torch.empty(size=(2, 2) + isects.shape, dtype=isects.dtype,
+ device=device)
+ arr[0, 0] = isects
+ arr[0, 1] = preds - isects
+ arr[1, 0] = truth - isects
+ arr[1, 1] = 1 - unions
+ arr.clamp_(0, 1)
+ mi = mutual_information(arr)
+ mi[:,:,-1] = 0 # at the 1.0 quantile should be no MI.
+        # Don't trust mi when label_frac is less than 1e-3,
+        # because our samples are too small.
+ mi[label_fracs.view(-1) < 1e-3, :, :] = 0
+ je = joint_entropy(arr)
+ iqr = mi / je
+ iqr[torch.isnan(iqr)] = 0 # Zero out any 0/0
+ full_mi[layer] = mi.cpu()
+ full_je[layer] = je.cpu()
+ full_iqr[layer] = iqr.cpu()
+ del mi, je
+ agreement = isects + arr[1, 1]
+ # When optimizing, maximize only over those pairs where the
+ # unit is positively correlated with the label, and where the
+ # threshold level is positive
+ positive_iqr = iqr
+ positive_iqr[agreement <= 0.8] = 0
+ positive_iqr[(levels <= 0.0)[None, :, :].expand(positive_iqr.shape)] = 0
+ # TODO: erase any for which threshold is bad
+ maxiqr, level_bucket = positive_iqr.max(2)
+ max_iqr[layer] = maxiqr.cpu()
+ max_iqr_level[layer] = levels.to(device)[
+ torch.arange(levels.shape[0])[None,:], level_bucket].cpu()
+ max_iqr_quantile[layer] = fracs.to(device)[level_bucket].cpu()
+ max_iqr_agreement[layer] = agreement[
+ torch.arange(agreement.shape[0])[:, None],
+ torch.arange(agreement.shape[1])[None, :],
+ level_bucket].cpu()
+
+ # Compute the iou that goes with each maximized iqr
+ matching_iou = (isects[
+ torch.arange(isects.shape[0])[:, None],
+ torch.arange(isects.shape[1])[None, :],
+ level_bucket] /
+ unions[
+ torch.arange(unions.shape[0])[:, None],
+ torch.arange(unions.shape[1])[None, :],
+ level_bucket])
+ matching_iou[torch.isnan(matching_iou)] = 0
+ max_iqr_iou[layer] = matching_iou.cpu()
+ for layer in model.retained_features():
+ numpy.savez(os.path.join(outdir, safe_dir_name(layer), 'iqr.npz'),
+ max_iqr=max_iqr[layer].cpu().numpy(),
+ max_iqr_level=max_iqr_level[layer].cpu().numpy(),
+ max_iqr_quantile=max_iqr_quantile[layer].cpu().numpy(),
+ max_iqr_iou=max_iqr_iou[layer].cpu().numpy(),
+ max_iqr_agreement=max_iqr_agreement[layer].cpu().numpy(),
+ full_mi=full_mi[layer].cpu().numpy(),
+ full_je=full_je[layer].cpu().numpy(),
+ full_iqr=full_iqr[layer].cpu().numpy())
+ return (max_iqr, max_iqr_level, max_iqr_quantile, max_iqr_iou,
+ max_iqr_agreement)
+
+def mutual_information(arr):
+ total = 0
+ for j in range(arr.shape[0]):
+ for k in range(arr.shape[1]):
+ joint = arr[j,k]
+ ind = arr[j,:].sum(dim=0) * arr[:,k].sum(dim=0)
+ term = joint * (joint / ind).log()
+ term[torch.isnan(term)] = 0
+ total += term
+ return total.clamp_(0)
+
+def joint_entropy(arr):
+ total = 0
+ for j in range(arr.shape[0]):
+ for k in range(arr.shape[1]):
+ joint = arr[j,k]
+ term = joint * joint.log()
+ term[torch.isnan(term)] = 0
+ total += term
+ return (-total).clamp_(0)
+
+def information_quality_ratio(arr):
+ iqr = mutual_information(arr) / joint_entropy(arr)
+ iqr[torch.isnan(iqr)] = 0
+ return iqr
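+
+# A small numeric check (illustrative, not part of the analysis pipeline): for
+# a 2x2 joint distribution, a perfectly dependent table gives IQR 1 and an
+# independent table gives IQR 0.
+def _example_iqr():
+    dependent = torch.tensor([[[0.5], [0.0]], [[0.0], [0.5]]])
+    independent = torch.tensor([[[0.25], [0.25]], [[0.25], [0.25]]])
+    return (information_quality_ratio(dependent),    # approximately tensor([1.])
+            information_quality_ratio(independent))  # tensor([0.])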
+
+def collect_covariance(outdir, model, segloader, segrunner):
+ '''
+ Returns label_mean, label_variance, unit_mean, unit_variance,
+ and cross_covariance across the data set.
+
+ label_mean, label_variance (independent of model):
+ treating the label as a one-hot, each label's mean and variance.
+ unit_mean, unit_variance (one per layer): for each feature channel,
+ the mean and variance of the activations in that channel.
+ cross_covariance (one per layer): the cross covariance between the
+ labels and the units in the layer.
+ '''
+ device = next(model.parameters()).device
+ cached_covariance = {
+ layer: load_covariance_if_present(os.path.join(outdir,
+ safe_dir_name(layer)), 'covariance.npz', device=device)
+ for layer in model.retained_features() }
+ if all(value is not None for value in cached_covariance.values()):
+ return cached_covariance
+ labelcat, categories = segrunner.get_label_and_category_names()
+ label_category = [categories.index(c) if c in categories else 0
+ for l, c in labelcat]
+ num_labels, num_categories = (len(n) for n in [labelcat, categories])
+
+ # Running covariance
+ cov = {}
+ progress = default_progress()
+ scale_offset_map = getattr(model, 'scale_offset', None)
+ upsample_grids = {}
+ for i, batch in enumerate(progress(segloader, desc='Covariance')):
+ seg, _, _, imshape = segrunner.run_and_segment_batch(batch, model,
+ want_rgb=True)
+ features = model.retained_features()
+ ohfeats = multilabel_onehot(seg, num_labels, ignore_index=0)
+ # Accumulate bincounts and identify nonzeros
+ for key, value in features.items():
+ if key not in upsample_grids:
+ upsample_grids[key] = upsample_grid(value.shape[2:],
+ seg.shape[2:], imshape,
+ scale_offset=scale_offset_map.get(key, None)
+ if scale_offset_map is not None else None,
+ dtype=value.dtype, device=value.device)
+ upsampled = torch.nn.functional.grid_sample(value,
+ upsample_grids[key].expand(
+ (value.shape[0],) + upsample_grids[key].shape[1:]),
+ padding_mode='border')
+ if key not in cov:
+ cov[key] = RunningCrossCovariance()
+ cov[key].add(upsampled, ohfeats)
+ for layer in cov:
+ save_state_dict(cov[layer],
+ os.path.join(outdir, safe_dir_name(layer), 'covariance.npz'))
+ return cov
+
+def multilabel_onehot(labels, num_labels, dtype=None, ignore_index=None):
+ '''
+ Converts a multilabel tensor into a onehot tensor.
+
+ The input labels is a tensor of shape (samples, multilabels, y, x).
+ The output is a tensor of shape (samples, num_labels, y, x).
+ If ignore_index is specified, labels with that index are ignored.
+ Each x in labels should be 0 <= x < num_labels, or x == ignore_index.
+ '''
+ assert ignore_index is None or ignore_index <= 0
+ if dtype is None:
+ dtype = torch.float
+ device = labels.device
+ chans = num_labels + (-ignore_index if ignore_index else 0)
+ outshape = (labels.shape[0], chans) + labels.shape[2:]
+ result = torch.zeros(outshape, device=device, dtype=dtype)
+ if ignore_index and ignore_index < 0:
+ labels = labels + (-ignore_index)
+ result.scatter_(1, labels, 1)
+ if ignore_index and ignore_index < 0:
+ result = result[:, -ignore_index:]
+ elif ignore_index is not None:
+ result[:, ignore_index] = 0
+ return result
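+
+# Illustrative usage sketch: a single-image, single-channel label map with
+# labels in {0, 1, 2} becomes a 3-channel one-hot map; label 0 (the "no-label"
+# background) is zeroed out by ignore_index=0.
+def _example_multilabel_onehot():
+    labels = torch.tensor([[[[0, 2], [1, 1]]]])           # shape (1, 1, 2, 2)
+    return multilabel_onehot(labels, 3, ignore_index=0)   # shape (1, 3, 2, 2)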
+
+def load_npy_if_present(outdir, filename, device):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ return torch.from_numpy(data).to(device)
+    return None
+
+def load_npz_if_present(outdir, filename, varnames, device):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ numpy_result = [data[n] for n in varnames]
+ return tuple(torch.from_numpy(data).to(device) for data in numpy_result)
+ return None
+
+def load_quantile_if_present(outdir, filename, device):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ result = RunningQuantile(state=data)
+ result.to_(device)
+ return result
+ return None
+
+def load_conditional_quantile_if_present(outdir, filename):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ result = RunningConditionalQuantile(state=data)
+ return result
+ return None
+
+def load_topk_if_present(outdir, filename, device):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ result = RunningTopK(state=data)
+ result.to_(device)
+ return result
+ return None
+
+def load_covariance_if_present(outdir, filename, device):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ result = RunningCrossCovariance(state=data)
+ result.to_(device)
+ return result
+ return None
+
+def save_state_dict(obj, filepath):
+ dirname = os.path.dirname(filepath)
+ os.makedirs(dirname, exist_ok=True)
+ dic = obj.state_dict()
+ numpy.savez(filepath, **dic)
+
+def upsample_grid(data_shape, target_shape, input_shape=None,
+ scale_offset=None, dtype=torch.float, device=None):
+ '''Prepares a grid to use with grid_sample to upsample a batch of
+ features in data_shape to the target_shape. Can use scale_offset
+ and input_shape to center the grid in a nondefault way: scale_offset
+ maps feature pixels to input_shape pixels, and it is assumed that
+ the target_shape is a uniform downsampling of input_shape.'''
+ # Default is that nothing is resized.
+ if target_shape is None:
+ target_shape = data_shape
+ # Make a default scale_offset to fill the image if there isn't one
+ if scale_offset is None:
+ scale = tuple(float(ts) / ds
+ for ts, ds in zip(target_shape, data_shape))
+ offset = tuple(0.5 * s - 0.5 for s in scale)
+ else:
+ scale, offset = (v for v in zip(*scale_offset))
+ # Handle downsampling for different input vs target shape.
+ if input_shape is not None:
+ scale = tuple(s * (ts - 1) / (ns - 1)
+ for s, ns, ts in zip(scale, input_shape, target_shape))
+ offset = tuple(o * (ts - 1) / (ns - 1)
+ for o, ns, ts in zip(offset, input_shape, target_shape))
+ # Pytorch needs target coordinates in terms of source coordinates [-1..1]
+ ty, tx = (((torch.arange(ts, dtype=dtype, device=device) - o)
+ * (2 / (s * (ss - 1))) - 1)
+ for ts, ss, s, o, in zip(target_shape, data_shape, scale, offset))
+ # Whoa, note that grid_sample reverses the order y, x -> x, y.
+ grid = torch.stack(
+ (tx[None,:].expand(target_shape), ty[:,None].expand(target_shape)),2
+ )[None,:,:,:].expand((1, target_shape[0], target_shape[1], 2))
+ return grid
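+
+# Illustrative usage sketch: build a sampling grid that stretches a 4x4
+# feature map to 8x8 and apply it with grid_sample, as the collectors above do.
+def _example_upsample_grid():
+    feat = torch.randn(1, 16, 4, 4)
+    grid = upsample_grid(feat.shape[2:], (8, 8), dtype=feat.dtype)
+    return torch.nn.functional.grid_sample(
+        feat, grid, padding_mode='border')                # shape (1, 16, 8, 8)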
+
+def safe_dir_name(filename):
+ keepcharacters = (' ','.','_','-')
+ return ''.join(c
+ for c in filename if c.isalnum() or c in keepcharacters).rstrip()
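+
+# For example, safe_dir_name('layer4/conv2:0') yields 'layer4conv20'.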
+
+bargraph_palette = [
+ ('#4B4CBF', '#B6B6F2'),
+ ('#55B05B', '#B6F2BA'),
+ ('#50BDAC', '#A5E5DB'),
+ ('#81C679', '#C0FF9B'),
+ ('#F0883B', '#F2CFB6'),
+ ('#D4CF24', '#F2F1B6'),
+ ('#D92E2B', '#F2B6B6'),
+ ('#AB6BC6', '#CFAAFF'),
+]
+
+def make_svg_bargraph(labels, heights, categories,
+ barheight=100, barwidth=12, show_labels=True, filename=None):
+ # if len(labels) == 0:
+ # return # Nothing to do
+ unitheight = float(barheight) / max(max(heights, default=1), 1)
+ textheight = barheight if show_labels else 0
+ labelsize = float(barwidth)
+ gap = float(barwidth) / 4
+ textsize = barwidth + gap
+ rollup = max(heights, default=1)
+ textmargin = float(labelsize) * 2 / 3
+ leftmargin = 32
+ rightmargin = 8
+ svgwidth = len(heights) * (barwidth + gap) + 2 * leftmargin + rightmargin
+ svgheight = barheight + textheight
+
+ # create an SVG XML element
+ svg = et.Element('svg', width=str(svgwidth), height=str(svgheight),
+ version='1.1', xmlns='http://www.w3.org/2000/svg')
+
+ # Draw the bar graph
+ basey = svgheight - textheight
+ x = leftmargin
+ # Add units scale on left
+ if len(heights):
+ for h in [1, (max(heights) + 1) // 2, max(heights)]:
+ et.SubElement(svg, 'text', x='0', y='0',
+ style=('font-family:sans-serif;font-size:%dpx;' +
+ 'text-anchor:end;alignment-baseline:hanging;' +
+ 'transform:translate(%dpx, %dpx);') %
+ (textsize, x - gap, basey - h * unitheight)).text = str(h)
+ et.SubElement(svg, 'text', x='0', y='0',
+ style=('font-family:sans-serif;font-size:%dpx;' +
+ 'text-anchor:middle;' +
+ 'transform:translate(%dpx, %dpx) rotate(-90deg)') %
+ (textsize, x - gap - textsize, basey - h * unitheight / 2)
+ ).text = 'units'
+ # Draw big category background rectangles
+ for catindex, (cat, catcount) in enumerate(categories):
+ if not catcount:
+ continue
+ et.SubElement(svg, 'rect', x=str(x), y=str(basey - rollup * unitheight),
+ width=(str((barwidth + gap) * catcount - gap)),
+ height = str(rollup*unitheight),
+ fill=bargraph_palette[catindex % len(bargraph_palette)][1])
+ x += (barwidth + gap) * catcount
+ # Draw small bars as well as 45degree text labels
+ x = leftmargin
+ catindex = -1
+ catcount = 0
+ for label, height in zip(labels, heights):
+ while not catcount and catindex <= len(categories):
+ catindex += 1
+ catcount = categories[catindex][1]
+ color = bargraph_palette[catindex % len(bargraph_palette)][0]
+ et.SubElement(svg, 'rect', x=str(x), y=str(basey-(height * unitheight)),
+ width=str(barwidth), height=str(height * unitheight),
+ fill=color)
+ x += barwidth
+ if show_labels:
+ et.SubElement(svg, 'text', x='0', y='0',
+ style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+
+ 'transform:translate(%dpx, %dpx) rotate(-45deg);') %
+ (labelsize, x, basey + textmargin)).text = readable(label)
+ x += gap
+ catcount -= 1
+ # Text labels for each category
+ x = leftmargin
+ for cat, catcount in categories:
+ if not catcount:
+ continue
+ et.SubElement(svg, 'text', x='0', y='0',
+ style=('font-family:sans-serif;font-size:%dpx;text-anchor:end;'+
+ 'transform:translate(%dpx, %dpx) rotate(-90deg);') %
+ (textsize, x + (barwidth + gap) * catcount - gap,
+ basey - rollup * unitheight + gap)).text = '%d %s' % (
+ catcount, readable(cat + ('s' if catcount != 1 else '')))
+ x += (barwidth + gap) * catcount
+ # Output - this is the bare svg.
+ result = et.tostring(svg)
+ if filename:
+ f = open(filename, 'wb')
+ # When writing to a file a special header is needed.
+        f.write(''.join([
+            '<?xml version="1.0" standalone="no"?>\n',
+            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"'
+            ' "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">\n']
+            ).encode('utf-8'))
+ f.write(result)
+ f.close()
+ return result
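+
+# Illustrative usage sketch (labels, heights, and groupings are made up for
+# this example): three bars split into an "object" group of two and a "part"
+# group of one, returned as raw SVG bytes.
+def _example_bargraph():
+    return make_svg_bargraph(
+        labels=['tree', 'grass', 'car'], heights=[12, 7, 3],
+        categories=[('object', 2), ('part', 1)])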
+
+readable_replacements = [(re.compile(r[0]), r[1]) for r in [
+ (r'-[sc]$', ''),
+ (r'_', ' '),
+ ]]
+
+def readable(label):
+ for pattern, subst in readable_replacements:
+        label = re.sub(pattern, subst, label)
+ return label
+
+def reverse_normalize_from_transform(transform):
+ '''
+ Crawl around the transforms attached to a dataset looking for a
+    Normalize transform, and return a corresponding ReverseNormalize,
+ or None if no normalization is found.
+ '''
+ if isinstance(transform, torchvision.transforms.Normalize):
+ return ReverseNormalize(transform.mean, transform.std)
+ t = getattr(transform, 'transform', None)
+ if t is not None:
+ return reverse_normalize_from_transform(t)
+ transforms = getattr(transform, 'transforms', None)
+ if transforms is not None:
+ for t in reversed(transforms):
+ result = reverse_normalize_from_transform(t)
+ if result is not None:
+ return result
+ return None
+
+class ReverseNormalize:
+ '''
+ Applies the reverse of torchvision.transforms.Normalize.
+ '''
+ def __init__(self, mean, stdev):
+ mean = numpy.array(mean)
+ stdev = numpy.array(stdev)
+ self.mean = torch.from_numpy(mean)[None,:,None,None].float()
+ self.stdev = torch.from_numpy(stdev)[None,:,None,None].float()
+ def __call__(self, data):
+ device = data.device
+ return data.mul(self.stdev.to(device)).add_(self.mean.to(device))
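+
+# Illustrative usage sketch (the normalization statistics are the usual
+# ImageNet values, shown only as an example): locate the Normalize step inside
+# a Compose and undo it on an already-normalized batch.
+def _example_reverse_normalize():
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize([0.485, 0.456, 0.406],
+                                         [0.229, 0.224, 0.225])])
+    recover = reverse_normalize_from_transform(transform)
+    normalized_batch = torch.randn(2, 3, 64, 64)
+    return recover(normalized_batch)       # back in roughly [0, 1] pixel units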
+
+class ImageOnlySegRunner:
+ def __init__(self, dataset, recover_image=None):
+ if recover_image is None:
+ recover_image = reverse_normalize_from_transform(dataset)
+ self.recover_image = recover_image
+ self.dataset = dataset
+ def get_label_and_category_names(self):
+ return [('-', '-')], ['-']
+ def run_and_segment_batch(self, batch, model,
+ want_bincount=False, want_rgb=False):
+ [im] = batch
+ device = next(model.parameters()).device
+ if want_rgb:
+ rgb = self.recover_image(im.clone()
+ ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
+ else:
+ rgb = None
+ # Stubs for seg and bc
+ seg = torch.zeros(im.shape[0], 1, 1, 1, dtype=torch.long)
+ bc = torch.ones(im.shape[0], 1, dtype=torch.long)
+ # Run the model.
+ model(im.to(device))
+ return seg, bc, rgb, im.shape[2:]
+
+class ClassifierSegRunner:
+ def __init__(self, dataset, recover_image=None):
+ # The dataset contains explicit segmentations
+ if recover_image is None:
+ recover_image = reverse_normalize_from_transform(dataset)
+ self.recover_image = recover_image
+ self.dataset = dataset
+ def get_label_and_category_names(self):
+ catnames = self.dataset.categories
+ label_and_cat_names = [(readable(label),
+ catnames[self.dataset.label_category[i]])
+ for i, label in enumerate(self.dataset.labels)]
+ return label_and_cat_names, catnames
+ def run_and_segment_batch(self, batch, model,
+ want_bincount=False, want_rgb=False):
+ '''
+ Runs the dissected model on one batch of the dataset, and
+ returns a multilabel semantic segmentation for the data.
+ Given a batch of size (n, c, y, x) the segmentation should
+ be a (long integer) tensor of size (n, d, y//r, x//r) where
+ d is the maximum number of simultaneous labels given to a pixel,
+ and where r is some (optional) resolution reduction factor.
+ In the segmentation returned, the label `0` is reserved for
+ the background "no-label".
+
+ In addition to the segmentation, bc, rgb, and shape are returned
+ where bc is a per-image bincount counting returned label pixels,
+ rgb is a viewable (n, y, x, rgb) byte image tensor for the data
+ for visualizations (reversing normalizations, for example), and
+ shape is the (y, x) size of the data. If want_bincount or
+ want_rgb are False, those return values may be None.
+ '''
+ im, seg, bc = batch
+ device = next(model.parameters()).device
+ if want_rgb:
+ rgb = self.recover_image(im.clone()
+ ).permute(0, 2, 3, 1).mul_(255).clamp(0, 255).byte()
+ else:
+ rgb = None
+ # Run the model.
+ model(im.to(device))
+ return seg, bc, rgb, im.shape[2:]
+
+class GeneratorSegRunner:
+ def __init__(self, segmenter):
+ # The segmentations are given by an algorithm
+ if segmenter is None:
+ segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')
+ self.segmenter = segmenter
+ self.num_classes = len(segmenter.get_label_and_category_names()[0])
+ def get_label_and_category_names(self):
+ return self.segmenter.get_label_and_category_names()
+ def run_and_segment_batch(self, batch, model,
+ want_bincount=False, want_rgb=False):
+ '''
+ Runs the dissected model on one batch of the dataset, and
+ returns a multilabel semantic segmentation for the data.
+ Given a batch of size (n, c, y, x) the segmentation should
+ be a (long integer) tensor of size (n, d, y//r, x//r) where
+ d is the maximum number of simultaneous labels given to a pixel,
+ and where r is some (optional) resolution reduction factor.
+ In the segmentation returned, the label `0` is reserved for
+ the background "no-label".
+
+ In addition to the segmentation, bc, rgb, and shape are returned
+ where bc is a per-image bincount counting returned label pixels,
+ rgb is a viewable (n, y, x, rgb) byte image tensor for the data
+ for visualizations (reversing normalizations, for example), and
+ shape is the (y, x) size of the data. If want_bincount or
+ want_rgb are False, those return values may be None.
+ '''
+ device = next(model.parameters()).device
+ z_batch = batch[0]
+ tensor_images = model(z_batch.to(device))
+ seg = self.segmenter.segment_batch(tensor_images, downsample=2)
+ if want_bincount:
+ index = torch.arange(z_batch.shape[0],
+ dtype=torch.long, device=device)
+ bc = (seg + index[:, None, None, None] * self.num_classes).view(-1
+ ).bincount(minlength=z_batch.shape[0] * self.num_classes)
+ bc = bc.view(z_batch.shape[0], self.num_classes)
+ else:
+ bc = None
+ if want_rgb:
+ images = ((tensor_images + 1) / 2 * 255)
+ rgb = images.permute(0, 2, 3, 1).clamp(0, 255).byte()
+ else:
+ rgb = None
+ return seg, bc, rgb, tensor_images.shape[2:]
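+
+# Illustration of the per-image bincount trick used above: offsetting each
+# image's labels by image_index * num_classes lets a single bincount call
+# produce per-image label histograms (the tiny tensors here are made up).
+def _example_per_image_bincount():
+    seg = torch.tensor([[[[1, 2], [2, 0]]], [[[0, 0], [3, 3]]]])  # (2, 1, 2, 2)
+    num_classes = 4
+    index = torch.arange(seg.shape[0])
+    bc = (seg + index[:, None, None, None] * num_classes).view(-1).bincount(
+        minlength=seg.shape[0] * num_classes).view(seg.shape[0], num_classes)
+    return bc      # tensor([[1, 1, 2, 0], [2, 0, 0, 2]])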
diff --git a/netdissect/easydict.py b/netdissect/easydict.py
new file mode 100644
index 0000000000000000000000000000000000000000..0188f524b87eef75c175772ff262b93b47919ba7
--- /dev/null
+++ b/netdissect/easydict.py
@@ -0,0 +1,126 @@
+'''
+From https://github.com/makinacorpus/easydict.
+'''
+
+class EasyDict(dict):
+ """
+ Get attributes
+
+ >>> d = EasyDict({'foo':3})
+ >>> d['foo']
+ 3
+ >>> d.foo
+ 3
+ >>> d.bar
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'EasyDict' object has no attribute 'bar'
+
+ Works recursively
+
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
+ >>> isinstance(d.bar, dict)
+ True
+ >>> d.bar.x
+ 1
+
+ Bullet-proof
+
+ >>> EasyDict({})
+ {}
+ >>> EasyDict(d={})
+ {}
+ >>> EasyDict(None)
+ {}
+ >>> d = {'a': 1}
+ >>> EasyDict(**d)
+ {'a': 1}
+
+ Set attributes
+
+ >>> d = EasyDict()
+ >>> d.foo = 3
+ >>> d.foo
+ 3
+ >>> d.bar = {'prop': 'value'}
+ >>> d.bar.prop
+ 'value'
+ >>> d
+ {'foo': 3, 'bar': {'prop': 'value'}}
+ >>> d.bar.prop = 'newer'
+ >>> d.bar.prop
+ 'newer'
+
+
+ Values extraction
+
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
+ >>> isinstance(d.bar, list)
+ True
+ >>> from operator import attrgetter
+ >>> map(attrgetter('x'), d.bar)
+ [1, 3]
+ >>> map(attrgetter('y'), d.bar)
+ [2, 4]
+ >>> d = EasyDict()
+ >>> d.keys()
+ []
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
+ >>> d.foo
+ 3
+ >>> d.bar.x
+ 1
+
+ Still like a dict though
+
+ >>> o = EasyDict({'clean':True})
+ >>> o.items()
+ [('clean', True)]
+
+ And like a class
+
+ >>> class Flower(EasyDict):
+ ... power = 1
+ ...
+ >>> f = Flower()
+ >>> f.power
+ 1
+ >>> f = Flower({'height': 12})
+ >>> f.height
+ 12
+ >>> f['power']
+ 1
+ >>> sorted(f.keys())
+ ['height', 'power']
+ """
+ def __init__(self, d=None, **kwargs):
+ if d is None:
+ d = {}
+ if kwargs:
+ d.update(**kwargs)
+ for k, v in d.items():
+ setattr(self, k, v)
+ # Class attributes
+ for k in self.__class__.__dict__.keys():
+ if not (k.startswith('__') and k.endswith('__')):
+ setattr(self, k, getattr(self, k))
+
+ def __setattr__(self, name, value):
+ if isinstance(value, (list, tuple)):
+ value = [self.__class__(x)
+ if isinstance(x, dict) else x for x in value]
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
+ value = self.__class__(value)
+ super(EasyDict, self).__setattr__(name, value)
+ super(EasyDict, self).__setitem__(name, value)
+
+ __setitem__ = __setattr__
+
+def load_json(filename):
+ import json
+ with open(filename) as f:
+ return EasyDict(json.load(f))
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/netdissect/edit.html b/netdissect/edit.html
new file mode 100644
index 0000000000000000000000000000000000000000..9aac30bb08171c4c58eb936f9ba382e85a184803
--- /dev/null
+++ b/netdissect/edit.html
@@ -0,0 +1,805 @@
+<!-- [edit.html body omitted; surviving fragments: template bindings
+     {{urec.iou_label}} and {{urec.layer}}{{urec.unit}}, a "Seeds to generate"
+     control, the example caption #{{ ex.id }}, and the instructions: "To
+     transfer activations from one pixel to another (1) click on a source pixel
+     on the left image and (2) click on a target pixel on a right image, then
+     (3) choose a set of units to insert in the palette."] -->
diff --git a/netdissect/evalablate.py b/netdissect/evalablate.py
new file mode 100644
index 0000000000000000000000000000000000000000..2079ffdb303b288df77678109f701e40fdf5779b
--- /dev/null
+++ b/netdissect/evalablate.py
@@ -0,0 +1,248 @@
+import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL
+from torchvision import transforms
+from torch.utils.data import TensorDataset
+from netdissect.progress import default_progress, post_progress, desc_progress
+from netdissect.progress import verbose_progress, print_progress
+from netdissect.nethook import edit_layers
+from netdissect.zdataset import standard_z_sample
+from netdissect.autoeval import autoimport_eval
+from netdissect.easydict import EasyDict
+from netdissect.modelconfig import create_instrumented_model
+
+help_epilog = '''\
+Example:
+
+python -m netdissect.evalablate \
+ --segmenter "netdissect.segmenter.UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')" \
+ --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \
+ --outdir dissect/dissectdir \
+ --classes mirror coffeetable tree \
+ --layers layer4 \
+ --size 1000
+
+Output layout:
+dissectdir/layer5/ablation/mirror-iqr.json
+{ class: "mirror",
+ classnum: 43,
+ pixel_total: 41342300,
+ class_pixels: 1234531,
+ layer: "layer5",
+ ranking: "mirror-iqr",
+ ablation_units: [341, 23, 12, 142, 83, ...]
+ ablation_pixels: [143242, 132344, 429931, ...]
+}
+
+'''
+
+def main():
+ # Training settings
+ def strpair(arg):
+ p = tuple(arg.split(':'))
+ if len(p) == 1:
+ p = p + p
+ return p
+
+ parser = argparse.ArgumentParser(description='Ablation eval',
+ epilog=textwrap.dedent(help_epilog),
+ formatter_class=argparse.RawDescriptionHelpFormatter)
+ parser.add_argument('--model', type=str, default=None,
+ help='constructor for the model to test')
+ parser.add_argument('--pthfile', type=str, default=None,
+ help='filename of .pth file for the model')
+ parser.add_argument('--outdir', type=str, default='dissect', required=True,
+ help='directory for dissection output')
+ parser.add_argument('--layers', type=strpair, nargs='+',
+ help='space-separated list of layer names to edit' +
+ ', in the form layername[:reportedname]')
+ parser.add_argument('--classes', type=str, nargs='+',
+ help='space-separated list of class names to ablate')
+ parser.add_argument('--metric', type=str, default='iou',
+ help='ordering metric for selecting units')
+ parser.add_argument('--unitcount', type=int, default=30,
+ help='number of units to ablate')
+ parser.add_argument('--segmenter', type=str,
+            help='constructor expression for the segmenter to use')
+ parser.add_argument('--netname', type=str, default=None,
+ help='name for network in generated reports')
+ parser.add_argument('--batch_size', type=int, default=5,
+ help='batch size for forward pass')
+ parser.add_argument('--size', type=int, default=200,
+ help='number of images to test')
+ parser.add_argument('--no-cuda', action='store_true', default=False,
+ help='disables CUDA usage')
+ parser.add_argument('--quiet', action='store_true', default=False,
+ help='silences console output')
+ if len(sys.argv) == 1:
+ parser.print_usage(sys.stderr)
+ sys.exit(1)
+ args = parser.parse_args()
+
+ # Set up console output
+ verbose_progress(not args.quiet)
+
+ # Speed up pytorch
+ torch.backends.cudnn.benchmark = True
+
+ # Set up CUDA
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
+ if args.cuda:
+ torch.backends.cudnn.benchmark = True
+
+ # Take defaults for model constructor etc from dissect.json settings.
+ with open(os.path.join(args.outdir, 'dissect.json')) as f:
+ dissection = EasyDict(json.load(f))
+ if args.model is None:
+ args.model = dissection.settings.model
+ if args.pthfile is None:
+ args.pthfile = dissection.settings.pthfile
+ if args.segmenter is None:
+ args.segmenter = dissection.settings.segmenter
+
+ # Instantiate generator
+ model = create_instrumented_model(args, gen=True, edit=True)
+ if model is None:
+ print('No model specified')
+ sys.exit(1)
+
+ # Instantiate model
+ device = next(model.parameters()).device
+ input_shape = model.input_shape
+
+ # 4d input if convolutional, 2d input if first layer is linear.
+ raw_sample = standard_z_sample(args.size, input_shape[1], seed=2).view(
+ (args.size,) + input_shape[1:])
+ dataset = TensorDataset(raw_sample)
+
+ # Create the segmenter
+ segmenter = autoimport_eval(args.segmenter)
+
+ # Now do the actual work.
+ labelnames, catnames = (
+ segmenter.get_label_and_category_names(dataset))
+ label_category = [catnames.index(c) if c in catnames else 0
+ for l, c in labelnames]
+ labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}
+
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=args.batch_size, num_workers=10,
+ pin_memory=(device.type == 'cuda'))
+
+ # Index the dissection layers by layer name.
+ dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}
+
+ # First, collect a baseline
+ for l in model.ablation:
+ model.ablation[l] = None
+
+ # For each sort-order, do an ablation
+ progress = default_progress()
+ for classname in progress(args.classes):
+ post_progress(c=classname)
+ for layername in progress(model.ablation):
+ post_progress(l=layername)
+ rankname = '%s-%s' % (classname, args.metric)
+ classnum = labelnum_from_name[classname]
+ try:
+ ranking = next(r for r in dissect_layer[layername].rankings
+ if r.name == rankname)
+            except (KeyError, StopIteration):
+ print('%s not found' % rankname)
+ sys.exit(1)
+ ordering = numpy.argsort(ranking.score)
+ # Check if already done
+ ablationdir = os.path.join(args.outdir, layername, 'pixablation')
+ if os.path.isfile(os.path.join(ablationdir, '%s.json'%rankname)):
+ with open(os.path.join(ablationdir, '%s.json'%rankname)) as f:
+ data = EasyDict(json.load(f))
+ # If the unit ordering is not the same, something is wrong
+ if not all(a == o
+ for a, o in zip(data.ablation_units, ordering)):
+ continue
+ if len(data.ablation_effects) >= args.unitcount:
+ continue # file already done.
+ measurements = data.ablation_effects
+ measurements = measure_ablation(segmenter, segloader,
+ model, classnum, layername, ordering[:args.unitcount])
+ measurements = measurements.cpu().numpy().tolist()
+ os.makedirs(ablationdir, exist_ok=True)
+ with open(os.path.join(ablationdir, '%s.json'%rankname), 'w') as f:
+ json.dump(dict(
+ classname=classname,
+ classnum=classnum,
+ baseline=measurements[0],
+ layer=layername,
+ metric=args.metric,
+ ablation_units=ordering.tolist(),
+ ablation_effects=measurements[1:]), f)
+
+def measure_ablation(segmenter, loader, model, classnum, layer, ordering):
+ total_bincount = 0
+ data_size = 0
+ device = next(model.parameters()).device
+ progress = default_progress()
+ for l in model.ablation:
+ model.ablation[l] = None
+ feature_units = model.feature_shape[layer][1]
+ feature_shape = model.feature_shape[layer][2:]
+ repeats = len(ordering)
+ total_scores = torch.zeros(repeats + 1)
+ for i, batch in enumerate(progress(loader)):
+ z_batch = batch[0]
+ model.ablation[layer] = None
+ tensor_images = model(z_batch.to(device))
+ seg = segmenter.segment_batch(tensor_images, downsample=2)
+ mask = (seg == classnum).max(1)[0]
+ downsampled_seg = torch.nn.functional.adaptive_avg_pool2d(
+ mask.float()[:,None,:,:], feature_shape)[:,0,:,:]
+ total_scores[0] += downsampled_seg.sum().cpu()
+ # Now we need to do an intervention for every location
+ # that had a nonzero downsampled_seg, if any.
+ interventions_needed = downsampled_seg.nonzero()
+ location_count = len(interventions_needed)
+ if location_count == 0:
+ continue
+ interventions_needed = interventions_needed.repeat(repeats, 1)
+ inter_z = batch[0][interventions_needed[:,0]].to(device)
+ inter_chan = torch.zeros(repeats, location_count, feature_units,
+ device=device)
+ for j, u in enumerate(ordering):
+ inter_chan[j:, :, u] = 1
+ inter_chan = inter_chan.view(len(inter_z), feature_units)
+ inter_loc = interventions_needed[:,1:]
+ scores = torch.zeros(len(inter_z))
+ batch_size = len(batch[0])
+ for j in range(0, len(inter_z), batch_size):
+ ibz = inter_z[j:j+batch_size]
+ ibl = inter_loc[j:j+batch_size].t()
+ imask = torch.zeros((len(ibz),) + feature_shape, device=ibz.device)
+ imask[(torch.arange(len(ibz)),) + tuple(ibl)] = 1
+ ibc = inter_chan[j:j+batch_size]
+ model.ablation[layer] = (
+ imask.float()[:,None,:,:] * ibc[:,:,None,None])
+ tensor_images = model(ibz)
+ seg = segmenter.segment_batch(tensor_images, downsample=2)
+ mask = (seg == classnum).max(1)[0]
+ downsampled_iseg = torch.nn.functional.adaptive_avg_pool2d(
+ mask.float()[:,None,:,:], feature_shape)[:,0,:,:]
+ scores[j:j+batch_size] = downsampled_iseg[
+ (torch.arange(len(ibz)),) + tuple(ibl)]
+ scores = scores.view(repeats, location_count).sum(1)
+ total_scores[1:] += scores
+ return total_scores
+
+def count_segments(segmenter, loader, model):
+    '''
+    Computes the normalized pixel count of each segmentation label over
+    the batches produced by the loader.
+    '''
+    total_bincount = 0
+    data_size = 0
+    device = next(model.parameters()).device
+    num_classes = len(segmenter.get_label_and_category_names()[0])
+    progress = default_progress()
+    for i, batch in enumerate(progress(loader)):
+        z_batch = batch[0]
+        tensor_images = model(z_batch.to(device))
+        seg = segmenter.segment_batch(tensor_images, downsample=2)
+        # Per-image bincount via the index-offset trick.
+        index = torch.arange(z_batch.shape[0], dtype=torch.long, device=device)
+        batch_label_counts = (seg + index[:, None, None, None] * num_classes
+            ).view(-1).bincount(minlength=z_batch.shape[0] * num_classes
+            ).view(z_batch.shape[0], num_classes)
+        data_size += seg.shape[0] * seg.shape[2] * seg.shape[3]
+        total_bincount += batch_label_counts.float().sum(0)
+    normalized_bincount = total_bincount / data_size
+    return normalized_bincount
+
+if __name__ == '__main__':
+ main()
diff --git a/netdissect/fullablate.py b/netdissect/fullablate.py
new file mode 100644
index 0000000000000000000000000000000000000000..f92d2c514c0b92b3f33653c5b53198c9fd09cb80
--- /dev/null
+++ b/netdissect/fullablate.py
@@ -0,0 +1,235 @@
+import torch, sys, os, argparse, textwrap, numbers, numpy, json, PIL
+from torchvision import transforms
+from torch.utils.data import TensorDataset
+from netdissect.progress import default_progress, post_progress, desc_progress
+from netdissect.progress import verbose_progress, print_progress
+from netdissect.nethook import edit_layers
+from netdissect.zdataset import standard_z_sample
+from netdissect.autoeval import autoimport_eval
+from netdissect.easydict import EasyDict
+from netdissect.modelconfig import create_instrumented_model
+
+help_epilog = '''\
+Example:
+
+python -m netdissect.evalablate \
+ --segmenter "netdissect.GanImageSegmenter(segvocab='lowres', segsizes=[160,288], segdiv='quad')" \
+ --model "proggan.from_pth_file('models/lsun_models/${SCENE}_lsun.pth')" \
+ --outdir dissect/dissectdir \
+ --classname tree \
+ --layer layer4 \
+ --size 1000
+
+Output layout:
+dissectdir/layer5/ablation/mirror-iqr.json
+{ class: "mirror",
+ classnum: 43,
+ pixel_total: 41342300,
+ class_pixels: 1234531,
+ layer: "layer5",
+ ranking: "mirror-iqr",
+ ablation_units: [341, 23, 12, 142, 83, ...]
+ ablation_pixels: [143242, 132344, 429931, ...]
+}
+
+'''
+
+def main():
+ # Training settings
+ def strpair(arg):
+ p = tuple(arg.split(':'))
+ if len(p) == 1:
+ p = p + p
+ return p
+
+ parser = argparse.ArgumentParser(description='Ablation eval',
+ epilog=textwrap.dedent(help_epilog),
+ formatter_class=argparse.RawDescriptionHelpFormatter)
+ parser.add_argument('--model', type=str, default=None,
+ help='constructor for the model to test')
+ parser.add_argument('--pthfile', type=str, default=None,
+ help='filename of .pth file for the model')
+ parser.add_argument('--outdir', type=str, default='dissect', required=True,
+ help='directory for dissection output')
+ parser.add_argument('--layer', type=strpair,
+ help='space-separated list of layer names to edit' +
+ ', in the form layername[:reportedname]')
+ parser.add_argument('--classname', type=str,
+ help='class name to ablate')
+ parser.add_argument('--metric', type=str, default='iou',
+ help='ordering metric for selecting units')
+ parser.add_argument('--unitcount', type=int, default=30,
+ help='number of units to ablate')
+ parser.add_argument('--segmenter', type=str,
+            help='constructor expression for the segmenter to use')
+ parser.add_argument('--netname', type=str, default=None,
+ help='name for network in generated reports')
+ parser.add_argument('--batch_size', type=int, default=25,
+ help='batch size for forward pass')
+ parser.add_argument('--mixed_units', action='store_true', default=False,
+ help='true to keep alpha for non-zeroed units')
+ parser.add_argument('--size', type=int, default=200,
+ help='number of images to test')
+ parser.add_argument('--no-cuda', action='store_true', default=False,
+ help='disables CUDA usage')
+ parser.add_argument('--quiet', action='store_true', default=False,
+ help='silences console output')
+ if len(sys.argv) == 1:
+ parser.print_usage(sys.stderr)
+ sys.exit(1)
+ args = parser.parse_args()
+
+ # Set up console output
+ verbose_progress(not args.quiet)
+
+ # Speed up pytorch
+ torch.backends.cudnn.benchmark = True
+
+ # Set up CUDA
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
+ if args.cuda:
+ torch.backends.cudnn.benchmark = True
+
+ # Take defaults for model constructor etc from dissect.json settings.
+ with open(os.path.join(args.outdir, 'dissect.json')) as f:
+ dissection = EasyDict(json.load(f))
+ if args.model is None:
+ args.model = dissection.settings.model
+ if args.pthfile is None:
+ args.pthfile = dissection.settings.pthfile
+ if args.segmenter is None:
+ args.segmenter = dissection.settings.segmenter
+ if args.layer is None:
+ args.layer = dissection.settings.layers[0]
+ args.layers = [args.layer]
+
+ # Also load specific analysis
+ layername = args.layer[1]
+ if args.metric == 'iou':
+ summary = dissection
+ else:
+ with open(os.path.join(args.outdir, layername, args.metric,
+ args.classname, 'summary.json')) as f:
+ summary = EasyDict(json.load(f))
+
+ # Instantiate generator
+ model = create_instrumented_model(args, gen=True, edit=True)
+ if model is None:
+ print('No model specified')
+ sys.exit(1)
+
+ # Instantiate model
+ device = next(model.parameters()).device
+ input_shape = model.input_shape
+
+ # 4d input if convolutional, 2d input if first layer is linear.
+ raw_sample = standard_z_sample(args.size, input_shape[1], seed=3).view(
+ (args.size,) + input_shape[1:])
+ dataset = TensorDataset(raw_sample)
+
+ # Create the segmenter
+ segmenter = autoimport_eval(args.segmenter)
+
+ # Now do the actual work.
+ labelnames, catnames = (
+ segmenter.get_label_and_category_names(dataset))
+ label_category = [catnames.index(c) if c in catnames else 0
+ for l, c in labelnames]
+ labelnum_from_name = {n[0]: i for i, n in enumerate(labelnames)}
+
+ segloader = torch.utils.data.DataLoader(dataset,
+ batch_size=args.batch_size, num_workers=10,
+ pin_memory=(device.type == 'cuda'))
+
+ # Index the dissection layers by layer name.
+
+ # First, collect a baseline
+ for l in model.ablation:
+ model.ablation[l] = None
+
+ # For each sort-order, do an ablation
+ progress = default_progress()
+ classname = args.classname
+ classnum = labelnum_from_name[classname]
+
+ # Get iou ranking from dissect.json
+ iou_rankname = '%s-%s' % (classname, 'iou')
+ dissect_layer = {lrec.layer: lrec for lrec in dissection.layers}
+ iou_ranking = next(r for r in dissect_layer[layername].rankings
+ if r.name == iou_rankname)
+
+ # Get trained ranking from summary.json
+ rankname = '%s-%s' % (classname, args.metric)
+ summary_layer = {lrec.layer: lrec for lrec in summary.layers}
+ ranking = next(r for r in summary_layer[layername].rankings
+ if r.name == rankname)
+
+ # Get ordering, first by ranking, then break ties by iou.
+ ordering = [t[2] for t in sorted([(s1, s2, i)
+ for i, (s1, s2) in enumerate(zip(ranking.score, iou_ranking.score))])]
+ values = (-numpy.array(ranking.score))[ordering]
+ if not args.mixed_units:
+ values[...] = 1
+
+ ablationdir = os.path.join(args.outdir, layername, 'fullablation')
+ measurements = measure_full_ablation(segmenter, segloader,
+ model, classnum, layername,
+ ordering[:args.unitcount], values[:args.unitcount])
+ measurements = measurements.cpu().numpy().tolist()
+ os.makedirs(ablationdir, exist_ok=True)
+ with open(os.path.join(ablationdir, '%s.json'%rankname), 'w') as f:
+ json.dump(dict(
+ classname=classname,
+ classnum=classnum,
+ baseline=measurements[0],
+ layer=layername,
+ metric=args.metric,
+ ablation_units=ordering,
+ ablation_values=values.tolist(),
+ ablation_effects=measurements[1:]), f)
+
+def measure_full_ablation(segmenter, loader, model, classnum, layer,
+ ordering, values):
+ '''
+ Quick and easy counting of segmented pixels reduced by ablating units.
+ '''
+ progress = default_progress()
+ device = next(model.parameters()).device
+ feature_units = model.feature_shape[layer][1]
+ feature_shape = model.feature_shape[layer][2:]
+ repeats = len(ordering)
+ total_scores = torch.zeros(repeats + 1)
+ print(ordering)
+ print(values.tolist())
+ with torch.no_grad():
+ for l in model.ablation:
+ model.ablation[l] = None
+ for i, [ibz] in enumerate(progress(loader)):
+ ibz = ibz.cuda()
+ for num_units in progress(range(len(ordering) + 1)):
+ ablation = torch.zeros(feature_units, device=device)
+ ablation[ordering[:num_units]] = torch.tensor(
+ values[:num_units]).to(ablation.device, ablation.dtype)
+ model.ablation[layer] = ablation
+ tensor_images = model(ibz)
+ seg = segmenter.segment_batch(tensor_images, downsample=2)
+ mask = (seg == classnum).max(1)[0]
+ total_scores[num_units] += mask.sum().float().cpu()
+ return total_scores
+
+def count_segments(segmenter, loader, model):
+    '''
+    Computes the normalized pixel count of each segmentation label over
+    the batches produced by the loader.
+    '''
+    total_bincount = 0
+    data_size = 0
+    device = next(model.parameters()).device
+    num_classes = len(segmenter.get_label_and_category_names()[0])
+    progress = default_progress()
+    for i, batch in enumerate(progress(loader)):
+        z_batch = batch[0]
+        tensor_images = model(z_batch.to(device))
+        seg = segmenter.segment_batch(tensor_images, downsample=2)
+        # Per-image bincount via the index-offset trick.
+        index = torch.arange(z_batch.shape[0], dtype=torch.long, device=device)
+        batch_label_counts = (seg + index[:, None, None, None] * num_classes
+            ).view(-1).bincount(minlength=z_batch.shape[0] * num_classes
+            ).view(z_batch.shape[0], num_classes)
+        data_size += seg.shape[0] * seg.shape[2] * seg.shape[3]
+        total_bincount += batch_label_counts.float().sum(0)
+    normalized_bincount = total_bincount / data_size
+    return normalized_bincount
+
+if __name__ == '__main__':
+ main()
diff --git a/netdissect/modelconfig.py b/netdissect/modelconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0ee37a809ea1bcbd803cd7d4e100e1bb93290c9
--- /dev/null
+++ b/netdissect/modelconfig.py
@@ -0,0 +1,144 @@
+'''
+Original from https://github.com/CSAILVision/GANDissect
+Modified by Erik Härkönen, 29.11.2019
+'''
+
+import numbers
+import torch
+from netdissect.autoeval import autoimport_eval
+from netdissect.progress import print_progress
+from netdissect.nethook import InstrumentedModel
+from netdissect.easydict import EasyDict
+
+def create_instrumented_model(args, **kwargs):
+ '''
+ Creates an instrumented model out of a namespace of arguments that
+ correspond to ArgumentParser command-line args:
+ model: a string to evaluate as a constructor for the model.
+ pthfile: (optional) filename of .pth file for the model.
+ layers: a list of layers to instrument, defaulted if not provided.
+ edit: True to instrument the layers for editing.
+ gen: True for a generator model. One-pixel input assumed.
+ imgsize: For non-generator models, (y, x) dimensions for RGB input.
+ cuda: True to use CUDA.
+
+ The constructed model will be decorated with the following attributes:
+ input_shape: (usually 4d) tensor shape for single-image input.
+ output_shape: 4d tensor shape for output.
+ feature_shape: map of layer names to 4d tensor shape for featuremaps.
+ retained: map of layernames to tensors, filled after every evaluation.
+ ablation: if editing, map of layernames to [0..1] alpha values to fill.
+ replacement: if editing, map of layernames to values to fill.
+
+ When editing, the feature value x will be replaced by:
+ `x = (replacement * ablation) + (x * (1 - ablation))`
+ '''
+
+ args = EasyDict(vars(args), **kwargs)
+
+ # Construct the network
+ if args.model is None:
+ print_progress('No model specified')
+ return None
+ if isinstance(args.model, torch.nn.Module):
+ model = args.model
+ else:
+ model = autoimport_eval(args.model)
+ # Unwrap any DataParallel-wrapped model
+ if isinstance(model, torch.nn.DataParallel):
+ model = next(model.children())
+
+ # Load its state dict
+ meta = {}
+ if getattr(args, 'pthfile', None) is not None:
+ data = torch.load(args.pthfile)
+ if 'state_dict' in data:
+ meta = {}
+ for key in data:
+ if isinstance(data[key], numbers.Number):
+ meta[key] = data[key]
+ data = data['state_dict']
+ submodule = getattr(args, 'submodule', None)
+ if submodule is not None and len(submodule):
+ remove_prefix = submodule + '.'
+ data = { k[len(remove_prefix):]: v for k, v in data.items()
+ if k.startswith(remove_prefix)}
+ if not len(data):
+ print_progress('No submodule %s found in %s' %
+ (submodule, args.pthfile))
+ return None
+ model.load_state_dict(data, strict=not getattr(args, 'unstrict', False))
+
+ # Decide which layers to instrument.
+ if getattr(args, 'layer', None) is not None:
+ args.layers = [args.layer]
+ if getattr(args, 'layers', None) is None:
+ # Skip wrappers with only one named model
+ container = model
+ prefix = ''
+ while len(list(container.named_children())) == 1:
+ name, container = next(container.named_children())
+ prefix += name + '.'
+ # Default to all nontrivial top-level layers except last.
+ args.layers = [prefix + name
+ for name, module in container.named_children()
+ if type(module).__module__ not in [
+ # Skip ReLU and other activations.
+ 'torch.nn.modules.activation',
+ # Skip pooling layers.
+ 'torch.nn.modules.pooling']
+ ][:-1]
+ print_progress('Defaulting to layers: %s' % ' '.join(args.layers))
+
+ # Now wrap the model for instrumentation.
+ model = InstrumentedModel(model)
+ model.meta = meta
+
+ # Instrument the layers.
+ model.retain_layers(args.layers)
+ model.eval()
+ if args.cuda:
+ model.cuda()
+
+ # Annotate input, output, and feature shapes
+ annotate_model_shapes(model,
+ gen=getattr(args, 'gen', False),
+ imgsize=getattr(args, 'imgsize', None),
+ latent_shape=getattr(args, 'latent_shape', None))
+ return model
+
+def annotate_model_shapes(model, gen=False, imgsize=None, latent_shape=None):
+ assert (imgsize is not None) or gen
+
+ # Figure the input shape.
+ if gen:
+ if latent_shape is None:
+ # We can guess a generator's input shape by looking at the model.
+ # Examine first conv in model to determine input feature size.
+ first_layer = [c for c in model.modules()
+ if isinstance(c, (torch.nn.Conv2d, torch.nn.ConvTranspose2d,
+ torch.nn.Linear))][0]
+ # 4d input if convolutional, 2d input if first layer is linear.
+ if isinstance(first_layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
+ input_shape = (1, first_layer.in_channels, 1, 1)
+ else:
+ input_shape = (1, first_layer.in_features)
+ else:
+ # Specify input shape manually
+ input_shape = latent_shape
+ else:
+ # For a classifier, the input image shape is given as an argument.
+ input_shape = (1, 3) + tuple(imgsize)
+
+ # Run the model once to observe feature shapes.
+ device = next(model.parameters()).device
+ dry_run = torch.zeros(input_shape).to(device)
+ with torch.no_grad():
+ output = model(dry_run)
+
+ # Annotate shapes.
+ model.input_shape = input_shape
+ model.feature_shape = { layer: feature.shape
+ for layer, feature in model.retained_features().items() }
+ model.output_shape = output.shape
+ return model
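+
+# Illustrative usage sketch (the model expression and layer names below are
+# made-up examples, not part of any shipped configuration): instrument two
+# ResNet-18 blocks of a torchvision classifier and read back their activations.
+def _example_instrumented_classifier():
+    import argparse
+    args = argparse.Namespace(
+        model='torchvision.models.resnet18(pretrained=False)',
+        pthfile=None, layers=['layer2', 'layer3'],
+        gen=False, imgsize=(224, 224), cuda=False)
+    inst = create_instrumented_model(args)
+    inst(torch.zeros(1, 3, 224, 224))
+    return inst.retained_features()   # OrderedDict: layer name -> activations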
diff --git a/netdissect/nethook.py b/netdissect/nethook.py
new file mode 100644
index 0000000000000000000000000000000000000000..f36e84ee0cae2de2c3be247498408cf66db3ee8f
--- /dev/null
+++ b/netdissect/nethook.py
@@ -0,0 +1,266 @@
+'''
+Utilities for instrumenting a torch model.
+
+InstrumentedModel will wrap a pytorch model and allow hooking
+arbitrary layers to monitor or modify their output directly.
+
+Modified by Erik Härkönen:
+- 29.11.2019: Unhooking bugfix
+- 25.01.2020: Offset edits, removed old API
+'''
+
+import torch, numpy, types
+from collections import OrderedDict
+
+class InstrumentedModel(torch.nn.Module):
+ '''
+ A wrapper for hooking, probing and intervening in pytorch Modules.
+ Example usage:
+
+ ```
+ model = load_my_model()
+    with InstrumentedModel(model) as inst:
+ inst.retain_layer(layername)
+ inst.edit_layer(layername, 0.5, target_features)
+ inst.edit_layer(layername, offset=offset_tensor)
+ inst(inputs)
+ original_features = inst.retained_layer(layername)
+ ```
+ '''
+
+ def __init__(self, model):
+ super(InstrumentedModel, self).__init__()
+ self.model = model
+ self._retained = OrderedDict()
+ self._ablation = {}
+ self._replacement = {}
+ self._offset = {}
+ self._hooked_layer = {}
+ self._old_forward = {}
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+ def forward(self, *inputs, **kwargs):
+ return self.model(*inputs, **kwargs)
+
+ def retain_layer(self, layername):
+ '''
+ Pass a fully-qualified layer name (E.g., module.submodule.conv3)
+ to hook that layer and retain its output each time the model is run.
+ A pair (layername, aka) can be provided, and the aka will be used
+ as the key for the retained value instead of the layername.
+ '''
+ self.retain_layers([layername])
+
+ def retain_layers(self, layernames):
+ '''
+        Retains a list of layers at once.
+ '''
+ self.add_hooks(layernames)
+ for layername in layernames:
+ aka = layername
+ if not isinstance(aka, str):
+ layername, aka = layername
+ if aka not in self._retained:
+ self._retained[aka] = None
+
+ def retained_features(self):
+ '''
+ Returns a dict of all currently retained features.
+ '''
+ return OrderedDict(self._retained)
+
+ def retained_layer(self, aka=None, clear=False):
+ '''
+ Retrieve retained data that was previously hooked by retain_layer.
+        Call this after the model is run. If clear is set, the
+        retained value is returned and then cleared.
+ '''
+ if aka is None:
+ # Default to the first retained layer.
+            aka = next(iter(self._retained))
+ result = self._retained[aka]
+ if clear:
+ self._retained[aka] = None
+ return result
+
+ def edit_layer(self, layername, ablation=None, replacement=None, offset=None):
+ '''
+ Pass a fully-qualified layer name (E.g., module.submodule.conv3)
+ to hook that layer and modify its output each time the model is run.
+ The output of the layer will be modified to be a convex combination
+ of the replacement and x interpolated according to the ablation, i.e.:
+ `output = x * (1 - a) + (r * a)`.
+ Additionally or independently, an offset can be added to the output.
+ '''
+ if not isinstance(layername, str):
+ layername, aka = layername
+ else:
+ aka = layername
+
+ # The default ablation if a replacement is specified is 1.0.
+ if ablation is None and replacement is not None:
+ ablation = 1.0
+ self.add_hooks([(layername, aka)])
+ if ablation is not None:
+ self._ablation[aka] = ablation
+ if replacement is not None:
+ self._replacement[aka] = replacement
+ if offset is not None:
+ self._offset[aka] = offset
+ # If needed, could add an arbitrary postprocessing lambda here.
+
+ def remove_edits(self, layername=None, remove_offset=True, remove_replacement=True):
+ '''
+ Removes edits at the specified layer, or removes edits at all layers
+ if no layer name is specified.
+ '''
+ if layername is None:
+ if remove_replacement:
+ self._ablation.clear()
+ self._replacement.clear()
+ if remove_offset:
+ self._offset.clear()
+ return
+
+ if not isinstance(layername, str):
+ layername, aka = layername
+ else:
+ aka = layername
+ if remove_replacement and aka in self._ablation:
+ del self._ablation[aka]
+ if remove_replacement and aka in self._replacement:
+ del self._replacement[aka]
+ if remove_offset and aka in self._offset:
+ del self._offset[aka]
+
+ def add_hooks(self, layernames):
+ '''
+ Sets up a set of layers to be hooked.
+
+ Usually not called directly: use edit_layer or retain_layer instead.
+ '''
+ needed = set()
+ aka_map = {}
+ for name in layernames:
+ aka = name
+ if not isinstance(aka, str):
+ name, aka = name
+ if self._hooked_layer.get(aka, None) != name:
+ aka_map[name] = aka
+ needed.add(name)
+ if not needed:
+ return
+ for name, layer in self.model.named_modules():
+ if name in aka_map:
+ needed.remove(name)
+ aka = aka_map[name]
+ self._hook_layer(layer, name, aka)
+ for name in needed:
+ raise ValueError('Layer %s not found in model' % name)
+
+ def _hook_layer(self, layer, layername, aka):
+ '''
+ Internal method to replace a forward method with a closure that
+ intercepts the call, and tracks the hook so that it can be reverted.
+ '''
+ if aka in self._hooked_layer:
+ raise ValueError('Layer %s already hooked' % aka)
+ if layername in self._old_forward:
+ raise ValueError('Layer %s already hooked' % layername)
+ self._hooked_layer[aka] = layername
+ self._old_forward[layername] = (layer, aka,
+ layer.__dict__.get('forward', None))
+ editor = self
+ original_forward = layer.forward
+ def new_forward(self, *inputs, **kwargs):
+ original_x = original_forward(*inputs, **kwargs)
+ x = editor._postprocess_forward(original_x, aka)
+ return x
+ layer.forward = types.MethodType(new_forward, layer)
+
+ def _unhook_layer(self, aka):
+ '''
+ Internal method to remove a hook, restoring the original forward method.
+ '''
+ if aka not in self._hooked_layer:
+ return
+ layername = self._hooked_layer[aka]
+ layer, check, old_forward = self._old_forward[layername]
+ assert check == aka
+ if old_forward is None:
+ if 'forward' in layer.__dict__:
+ del layer.__dict__['forward']
+ else:
+ layer.forward = old_forward
+ del self._old_forward[layername]
+ del self._hooked_layer[aka]
+ if aka in self._ablation:
+ del self._ablation[aka]
+ if aka in self._replacement:
+ del self._replacement[aka]
+ if aka in self._offset:
+ del self._offset[aka]
+ if aka in self._retained:
+ del self._retained[aka]
+
+ def _postprocess_forward(self, x, aka):
+ '''
+ The internal method called by the hooked layers after they are run.
+ '''
+ # Retain output before edits, if desired.
+ if aka in self._retained:
+ self._retained[aka] = x.detach()
+
+ # Apply replacement edit
+ a = make_matching_tensor(self._ablation, aka, x)
+ if a is not None:
+ x = x * (1 - a)
+ v = make_matching_tensor(self._replacement, aka, x)
+ if v is not None:
+ x += (v * a)
+
+ # Apply offset edit
+ b = make_matching_tensor(self._offset, aka, x)
+ if b is not None:
+ x = x + b
+
+ return x
+
+ def close(self):
+ '''
+ Unhooks all hooked layers in the model.
+ '''
+ for aka in list(self._old_forward.keys()):
+ self._unhook_layer(aka)
+ assert len(self._old_forward) == 0
+
+
+def make_matching_tensor(valuedict, name, data):
+ '''
+ Converts `valuedict[name]` to be a tensor with the same dtype, device,
+ and dimension count as `data`, and caches the converted tensor.
+ '''
+ v = valuedict.get(name, None)
+ if v is None:
+ return None
+ if not isinstance(v, torch.Tensor):
+ # Accept non-torch data.
+ v = torch.from_numpy(numpy.array(v))
+ valuedict[name] = v
+ if not v.device == data.device or not v.dtype == data.dtype:
+ # Ensure device and type matches.
+ assert not v.requires_grad, '%s wrong device or type' % (name)
+ v = v.to(device=data.device, dtype=data.dtype)
+ valuedict[name] = v
+ if len(v.shape) < len(data.shape):
+ # Ensure dimensions are unsqueezed as needed.
+ assert not v.requires_grad, '%s wrong dimensions' % (name)
+ v = v.view((1,) + tuple(v.shape) +
+ (1,) * (len(data.shape) - len(v.shape) - 1))
+ valuedict[name] = v
+ return v
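+
+if __name__ == '__main__':
+    # Illustrative sketch (editor's addition, not part of the original code):
+    # shows how make_matching_tensor broadcasts the ablation, replacement,
+    # and offset values stored by edit_layer, and how _postprocess_forward
+    # combines them as output = x * (1 - a) + r * a + b. The layer name
+    # 'layer4' is a made-up placeholder; torch and numpy come from the
+    # module-level imports above.
+    x = torch.randn(2, 4, 8, 8)                 # a pretend layer output NCHW
+    ablation = {'layer4': 0.5}                  # scalar a, broadcast to NCHW
+    replacement = {'layer4': torch.zeros(4)}    # per-channel replacement r
+    offset = {'layer4': 0.1}                    # scalar offset b
+    a = make_matching_tensor(ablation, 'layer4', x)
+    r = make_matching_tensor(replacement, 'layer4', x)
+    b = make_matching_tensor(offset, 'layer4', x)
+    edited = x * (1 - a) + r * a + b
+    print(edited.shape)                         # torch.Size([2, 4, 8, 8])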
diff --git a/netdissect/parallelfolder.py b/netdissect/parallelfolder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a741691569a7c85e96d3b3d9be12b40d508f0044
--- /dev/null
+++ b/netdissect/parallelfolder.py
@@ -0,0 +1,118 @@
+'''
+Variants of pytorch's ImageFolder for loading image datasets with more
+information, such as parallel feature channels in separate files,
+cached files with lists of filenames, etc.
+'''
+
+import os, torch, re
+import torch.utils.data as data
+from torchvision.datasets.folder import default_loader
+from PIL import Image
+from collections import OrderedDict
+from .progress import default_progress
+
+def grayscale_loader(path):
+ with open(path, 'rb') as f:
+ return Image.open(f).convert('L')
+
+class ParallelImageFolders(data.Dataset):
+ """
+ A data loader that looks for parallel image filenames, for example
+
+ photo1/park/004234.jpg
+ photo1/park/004236.jpg
+ photo1/park/004237.jpg
+
+ photo2/park/004234.png
+ photo2/park/004236.png
+ photo2/park/004237.png
+ """
+ def __init__(self, image_roots,
+ transform=None,
+ loader=default_loader,
+ stacker=None,
+ intersection=False,
+ verbose=None,
+ size=None):
+ self.image_roots = image_roots
+ self.images = make_parallel_dataset(image_roots,
+ intersection=intersection, verbose=verbose)
+ if len(self.images) == 0:
+ raise RuntimeError("Found 0 images within: %s" % image_roots)
+ if size is not None:
+ self.images = self.images[:size]
+ if transform is not None and not hasattr(transform, '__iter__'):
+ transform = [transform for _ in image_roots]
+ self.transforms = transform
+ self.stacker = stacker
+ self.loader = loader
+
+ def __getitem__(self, index):
+ paths = self.images[index]
+ sources = [self.loader(path) for path in paths]
+ # Add a common shared state dict to allow random crops/flips to be
+ # coordinated.
+ shared_state = {}
+ for s in sources:
+ s.shared_state = shared_state
+ if self.transforms is not None:
+ sources = [transform(source)
+ for source, transform in zip(sources, self.transforms)]
+ if self.stacker is not None:
+ sources = self.stacker(sources)
+ else:
+ sources = tuple(sources)
+ return sources
+
+ def __len__(self):
+ return len(self.images)
+
+def is_npy_file(path):
+ return path.endswith('.npy') or path.endswith('.NPY')
+
+def is_image_file(path):
+ return re.search(r'\.(jpe?g|png)$', path, re.IGNORECASE) is not None
+
+def walk_image_files(rootdir, verbose=None):
+ progress = default_progress(verbose)
+ indexfile = '%s.txt' % rootdir
+ if os.path.isfile(indexfile):
+ basedir = os.path.dirname(rootdir)
+ with open(indexfile) as f:
+ result = sorted([os.path.join(basedir, line.strip())
+ for line in progress(f.readlines(),
+ desc='Reading %s' % os.path.basename(indexfile))])
+ return result
+ result = []
+ for dirname, _, fnames in sorted(progress(os.walk(rootdir),
+ desc='Walking %s' % os.path.basename(rootdir))):
+ for fname in sorted(fnames):
+ if is_image_file(fname) or is_npy_file(fname):
+ result.append(os.path.join(dirname, fname))
+ return result
+
+def make_parallel_dataset(image_roots, intersection=False, verbose=None):
+ """
+ Returns [(img1, img2), (img1, img2)..]
+ """
+ image_roots = [os.path.expanduser(d) for d in image_roots]
+ image_sets = OrderedDict()
+ for j, root in enumerate(image_roots):
+ for path in walk_image_files(root, verbose=verbose):
+ key = os.path.splitext(os.path.relpath(path, root))[0]
+ if key not in image_sets:
+ image_sets[key] = []
+ if not intersection and len(image_sets[key]) != j:
+ raise RuntimeError(
+ 'Images not parallel: %s missing from one dir' % (key))
+ image_sets[key].append(path)
+ tuples = []
+ for key, value in image_sets.items():
+ if len(value) != len(image_roots):
+ if intersection:
+ continue
+ else:
+ raise RuntimeError(
+ 'Images not parallel: %s missing from one dir' % (key))
+ tuples.append(tuple(value))
+ return tuples
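+
+if __name__ == '__main__':
+    # Illustrative sketch (editor's addition): build two tiny parallel
+    # folders in a temporary directory and load them as paired samples.
+    # All directory and file names below are made up for the demo.
+    import tempfile
+    root = tempfile.mkdtemp()
+    for sub in [('photo1', 'park'), ('photo2', 'park')]:
+        os.makedirs(os.path.join(root, *sub))
+    for name in ['004234', '004236']:
+        Image.new('RGB', (8, 8)).save(
+            os.path.join(root, 'photo1', 'park', name + '.jpg'))
+        Image.new('L', (8, 8)).save(
+            os.path.join(root, 'photo2', 'park', name + '.png'))
+    ds = ParallelImageFolders(
+        [os.path.join(root, 'photo1'), os.path.join(root, 'photo2')])
+    print(len(ds), [im.size for im in ds[0]])   # 2 paired samples of (8, 8)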
diff --git a/netdissect/pidfile.py b/netdissect/pidfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..96a66814326bad444606ad829307fe225f4135e1
--- /dev/null
+++ b/netdissect/pidfile.py
@@ -0,0 +1,81 @@
+'''
+Utility for simple distribution of work on multiple processes, by
+making sure only one process is working on a job at once.
+'''
+
+import os, errno, socket, atexit, time, sys
+
+def exit_if_job_done(directory):
+ if pidfile_taken(os.path.join(directory, 'lockfile.pid'), verbose=True):
+ sys.exit(0)
+ if os.path.isfile(os.path.join(directory, 'done.txt')):
+ with open(os.path.join(directory, 'done.txt')) as f:
+ msg = f.read()
+ print(msg)
+ sys.exit(0)
+
+def mark_job_done(directory):
+ with open(os.path.join(directory, 'done.txt'), 'w') as f:
+ f.write('Done by %d@%s %s at %s' %
+ (os.getpid(), socket.gethostname(),
+ os.getenv('STY', ''),
+ time.strftime('%c')))
+
+def pidfile_taken(path, verbose=False):
+ '''
+ Usage. To grab an exclusive lock for the remaining duration of the
+ current process (and exit if another process already has the lock),
+ do this:
+
+ if pidfile_taken('job_423/lockfile.pid', verbose=True):
+ sys.exit(0)
+
+ To run a batch of jobs, run the same script on every available
+ machine, sharing a network filesystem. Because each job grabs a
+ lock before starting, the jobs are distributed automatically so
+ that each one is done just once, on one machine.
+ '''
+
+ # Try to create the file exclusively and write my pid into it.
+ try:
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ # If the file already exists, another process holds the lock; report the holder.
+ conflicter = 'race'
+ try:
+ with open(path, 'r') as lockfile:
+ conflicter = lockfile.read().strip() or 'empty'
+ except:
+ pass
+ if verbose:
+ print('%s held by %s' % (path, conflicter))
+ return conflicter
+ else:
+ # Other problems get an exception.
+ raise
+ # Register to delete this file on exit.
+ lockfile = os.fdopen(fd, 'r+')
+ atexit.register(delete_pidfile, lockfile, path)
+ # Write my pid into the open file.
+ lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(),
+ os.getenv('STY', '')))
+ lockfile.flush()
+ os.fsync(lockfile)
+ # Return 'None' to say there was not a conflict.
+ return None
+
+def delete_pidfile(lockfile, path):
+ '''
+ Runs at exit after pidfile_taken succeeds.
+ '''
+ if lockfile is not None:
+ try:
+ lockfile.close()
+ except:
+ pass
+ try:
+ os.unlink(path)
+ except:
+ pass
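+
+if __name__ == '__main__':
+    # Illustrative sketch (editor's addition): the intended pattern for a
+    # job script that shares work across machines. The directory name is
+    # made up; running this creates results/demo_job in the working dir.
+    jobdir = os.path.join('results', 'demo_job')
+    os.makedirs(jobdir, exist_ok=True)
+    exit_if_job_done(jobdir)    # exits if locked by another process or done
+    print('running the job in %s' % jobdir)
+    mark_job_done(jobdir)       # records completion in done.txt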
diff --git a/netdissect/plotutil.py b/netdissect/plotutil.py
new file mode 100644
index 0000000000000000000000000000000000000000..187bcb9d5615c8ec51a43148b011c06b8ed6aff7
--- /dev/null
+++ b/netdissect/plotutil.py
@@ -0,0 +1,61 @@
+import matplotlib.pyplot as plt
+import numpy
+
+def plot_tensor_images(data, **kwargs):
+ data = ((data + 1) / 2 * 255).permute(0, 2, 3, 1).byte().cpu().numpy()
+ width = int(numpy.ceil(numpy.sqrt(data.shape[0])))
+ height = int(numpy.ceil(data.shape[0] / float(width)))
+ kwargs = dict(kwargs)
+ margin = 0.01
+ if 'figsize' not in kwargs:
+ # Size figure to one display pixel per data pixel
+ dpi = plt.rcParams['figure.dpi']
+ kwargs['figsize'] = (
+ (1 + margin) * (width * data.shape[2] / dpi),
+ (1 + margin) * (height * data.shape[1] / dpi))
+ f, axarr = plt.subplots(height, width, **kwargs)
+ if len(numpy.shape(axarr)) == 0:
+ axarr = numpy.array([[axarr]])
+ if len(numpy.shape(axarr)) == 1:
+ axarr = axarr[None,:]
+ for i, im in enumerate(data):
+ ax = axarr[i // width, i % width]
+ ax.imshow(data[i])
+ ax.axis('off')
+ for i in range(i, width * height):
+ ax = axarr[i // width, i % width]
+ ax.axis('off')
+ plt.subplots_adjust(wspace=margin, hspace=margin,
+ left=0, right=1, bottom=0, top=1)
+ plt.show()
+
+def plot_max_heatmap(data, shape=None, **kwargs):
+ if shape is None:
+ shape = data.shape[2:]
+ data = data.max(1)[0].cpu().numpy()
+ vmin = data.min()
+ vmax = data.max()
+ width = int(numpy.ceil(numpy.sqrt(data.shape[0])))
+ height = int(numpy.ceil(data.shape[0] / float(width)))
+ kwargs = dict(kwargs)
+ margin = 0.01
+ if 'figsize' not in kwargs:
+ # Size figure to one display pixel per data pixel
+ dpi = plt.rcParams['figure.dpi']
+ kwargs['figsize'] = (
+ width * shape[1] / dpi, height * shape[0] / dpi)
+ f, axarr = plt.subplots(height, width, **kwargs)
+ if len(numpy.shape(axarr)) == 0:
+ axarr = numpy.array([[axarr]])
+ if len(numpy.shape(axarr)) == 1:
+ axarr = axarr[None,:]
+ for i, im in enumerate(data):
+ ax = axarr[i // width, i % width]
+ img = ax.imshow(data[i], vmin=vmin, vmax=vmax, cmap='hot')
+ ax.axis('off')
+ for i in range(i, width * height):
+ ax = axarr[i // width, i % width]
+ ax.axis('off')
+ plt.subplots_adjust(wspace=margin, hspace=margin,
+ left=0, right=1, bottom=0, top=1)
+ plt.show()
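+
+if __name__ == '__main__':
+    # Illustrative sketch (editor's addition): plot a small batch of random
+    # images in [-1, 1] and a random activation heatmap. Assumes torch is
+    # installed and a matplotlib display backend is available.
+    import torch
+    plot_tensor_images(torch.rand(4, 3, 32, 32) * 2 - 1)
+    plot_max_heatmap(torch.randn(4, 8, 16, 16))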
diff --git a/netdissect/proggan.py b/netdissect/proggan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e37ae15f373ef6ad14279bb581042434c5563539
--- /dev/null
+++ b/netdissect/proggan.py
@@ -0,0 +1,299 @@
+import torch, numpy, itertools
+import torch.nn as nn
+from collections import OrderedDict
+
+
+def print_network(net, verbose=False):
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('Total number of parameters: {:3.3f} M'.format(num_params / 1e6))
+
+
+def from_pth_file(filename):
+ '''
+ Instantiate from a pth file.
+ '''
+ state_dict = torch.load(filename)
+ if 'state_dict' in state_dict:
+ state_dict = state_dict['state_dict']
+ # Convert old version of parameter names
+ if 'features.0.conv.weight' in state_dict:
+ state_dict = state_dict_from_old_pt_dict(state_dict)
+ sizes = sizes_from_state_dict(state_dict)
+ result = ProgressiveGenerator(sizes=sizes)
+ result.load_state_dict(state_dict)
+ return result
+
+###############################################################################
+# Modules
+###############################################################################
+
+class ProgressiveGenerator(nn.Sequential):
+ def __init__(self, resolution=None, sizes=None, modify_sequence=None,
+ output_tanh=False):
+ '''
+ A pytorch progressive GAN generator that can be converted directly
+ from either a tensorflow model or a theano model. It consists of
+ a sequence of convolutional layers, organized in pairs, with an
+ upsampling and reduction of channels at every other layer; and
+ then finally followed by an output layer that reduces it to an
+ RGB [-1..1] image.
+
+ The network can be given more layers to increase the output
+ resolution. The sizes argument indicates the feature depth at
+ each upsampling, starting with the input z: [input-dim, 4x4-depth,
+ 8x8-depth, 16x16-depth...]. The output dimension is 2 ** len(sizes).
+
+ Some default architectures can be selected by supplying the
+ resolution argument instead.
+
+ The optional modify_sequence function can be used to transform the
+ sequence of layers before the network is constructed.
+
+ If output_tanh is set to True, the network applies a tanh to clamp
+ the output to [-1,1] before output; otherwise the output is unclamped.
+ '''
+ assert (resolution is None) != (sizes is None)
+ if sizes is None:
+ sizes = {
+ 8: [512, 512, 512],
+ 16: [512, 512, 512, 512],
+ 32: [512, 512, 512, 512, 256],
+ 64: [512, 512, 512, 512, 256, 128],
+ 128: [512, 512, 512, 512, 256, 128, 64],
+ 256: [512, 512, 512, 512, 256, 128, 64, 32],
+ 1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16]
+ }[resolution]
+ # Follow the schedule of upsampling given by sizes.
+ # layers are called: layer1, layer2, etc; then output_128x128
+ sequence = []
+ def add_d(layer, name=None):
+ if name is None:
+ name = 'layer%d' % (len(sequence) + 1)
+ sequence.append((name, layer))
+ add_d(NormConvBlock(sizes[0], sizes[1], kernel_size=4, padding=3))
+ add_d(NormConvBlock(sizes[1], sizes[1], kernel_size=3, padding=1))
+ for i, (si, so) in enumerate(zip(sizes[1:-1], sizes[2:])):
+ add_d(NormUpscaleConvBlock(si, so, kernel_size=3, padding=1))
+ add_d(NormConvBlock(so, so, kernel_size=3, padding=1))
+ # Create an output layer. During training, the progressive GAN
+ # learns several such output layers for various resolutions; we
+ # just include the last (highest resolution) one.
+ dim = 4 * (2 ** (len(sequence) // 2 - 1))
+ add_d(OutputConvBlock(sizes[-1], tanh=output_tanh),
+ name='output_%dx%d' % (dim, dim))
+ # Allow the sequence to be modified
+ if modify_sequence is not None:
+ sequence = modify_sequence(sequence)
+ super().__init__(OrderedDict(sequence))
+
+ def forward(self, x):
+ # Convert vector input to 1x1 featuremap.
+ x = x.view(x.shape[0], x.shape[1], 1, 1)
+ return super().forward(x)
+
+class PixelNormLayer(nn.Module):
+ def __init__(self):
+ super(PixelNormLayer, self).__init__()
+
+ def forward(self, x):
+ return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+class DoubleResolutionLayer(nn.Module):
+ def forward(self, x):
+ x = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+ return x
+
+class WScaleLayer(nn.Module):
+ def __init__(self, size, fan_in, gain=numpy.sqrt(2)):
+ super(WScaleLayer, self).__init__()
+ self.scale = gain / numpy.sqrt(fan_in) # No longer a parameter
+ self.b = nn.Parameter(torch.randn(size))
+ self.size = size
+
+ def forward(self, x):
+ x_size = x.size()
+ x = x * self.scale + self.b.view(1, -1, 1, 1).expand(
+ x_size[0], self.size, x_size[2], x_size[3])
+ return x
+
+class NormConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, padding):
+ super(NormConvBlock, self).__init__()
+ self.norm = PixelNormLayer()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, kernel_size, 1, padding, bias=False)
+ self.wscale = WScaleLayer(out_channels, in_channels,
+ gain=numpy.sqrt(2) / kernel_size)
+ self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2)
+
+ def forward(self, x):
+ x = self.norm(x)
+ x = self.conv(x)
+ x = self.relu(self.wscale(x))
+ return x
+
+class NormUpscaleConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, padding):
+ super(NormUpscaleConvBlock, self).__init__()
+ self.norm = PixelNormLayer()
+ self.up = DoubleResolutionLayer()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, kernel_size, 1, padding, bias=False)
+ self.wscale = WScaleLayer(out_channels, in_channels,
+ gain=numpy.sqrt(2) / kernel_size)
+ self.relu = nn.LeakyReLU(inplace=True, negative_slope=0.2)
+
+ def forward(self, x):
+ x = self.norm(x)
+ x = self.up(x)
+ x = self.conv(x)
+ x = self.relu(self.wscale(x))
+ return x
+
+class OutputConvBlock(nn.Module):
+ def __init__(self, in_channels, tanh=False):
+ super().__init__()
+ self.norm = PixelNormLayer()
+ self.conv = nn.Conv2d(
+ in_channels, 3, kernel_size=1, padding=0, bias=False)
+ self.wscale = WScaleLayer(3, in_channels, gain=1)
+ self.clamp = nn.Hardtanh() if tanh else (lambda x: x)
+
+ def forward(self, x):
+ x = self.norm(x)
+ x = self.conv(x)
+ x = self.wscale(x)
+ x = self.clamp(x)
+ return x
+
+###############################################################################
+# Conversion
+###############################################################################
+
+def from_tf_parameters(parameters):
+ '''
+ Instantiate from tensorflow variables.
+ '''
+ state_dict = state_dict_from_tf_parameters(parameters)
+ sizes = sizes_from_state_dict(state_dict)
+ result = ProgressiveGenerator(sizes=sizes)
+ result.load_state_dict(state_dict)
+ return result
+
+def from_old_pt_dict(parameters):
+ '''
+ Instantiate from old pytorch state dict.
+ '''
+ state_dict = state_dict_from_old_pt_dict(parameters)
+ sizes = sizes_from_state_dict(state_dict)
+ result = ProgressiveGenerator(sizes=sizes)
+ result.load_state_dict(state_dict)
+ return result
+
+def sizes_from_state_dict(params):
+ '''
+ In a progressive GAN, the number of channels can change after each
+ upsampling. This function reads the state dict to figure the
+ number of upsamplings and the channel depth of each filter.
+ '''
+ sizes = []
+ for i in itertools.count():
+ pt_layername = 'layer%d' % (i + 1)
+ try:
+ weight = params['%s.conv.weight' % pt_layername]
+ except KeyError:
+ break
+ if i == 0:
+ sizes.append(weight.shape[1])
+ if i % 2 == 0:
+ sizes.append(weight.shape[0])
+ return sizes
+
+def state_dict_from_tf_parameters(parameters):
+ '''
+ Conversion from tensorflow parameters
+ '''
+ def torch_from_tf(data):
+ return torch.from_numpy(data.eval())
+
+ params = dict(parameters)
+ result = {}
+ sizes = []
+ for i in itertools.count():
+ resolution = 4 * (2 ** (i // 2))
+ # Translate parameter names. For example:
+ # 4x4/Dense/weight -> layer1.conv.weight
+ # 32x32/Conv0_up/weight -> layer7.conv.weight
+ # 32x32/Conv1/weight -> layer8.conv.weight
+ tf_layername = '%dx%d/%s' % (resolution, resolution,
+ 'Dense' if i == 0 else 'Conv' if i == 1 else
+ 'Conv0_up' if i % 2 == 0 else 'Conv1')
+ pt_layername = 'layer%d' % (i + 1)
+ # Stop looping when we run out of parameters.
+ try:
+ weight = torch_from_tf(params['%s/weight' % tf_layername])
+ except KeyError:
+ break
+ # Transpose convolution weights into pytorch format.
+ if i == 0:
+ # Convert dense layer to 4x4 convolution
+ weight = weight.view(weight.shape[0], weight.shape[1] // 16,
+ 4, 4).permute(1, 0, 2, 3).flip(2, 3)
+ sizes.append(weight.shape[0])
+ elif i % 2 == 0:
+ # Convert inverse convolution to convolution
+ weight = weight.permute(2, 3, 0, 1).flip(2, 3)
+ else:
+ # Ordinary Conv2d conversion.
+ weight = weight.permute(3, 2, 0, 1)
+ sizes.append(weight.shape[1])
+ result['%s.conv.weight' % (pt_layername)] = weight
+ # Copy bias vector.
+ bias = torch_from_tf(params['%s/bias' % tf_layername])
+ result['%s.wscale.b' % (pt_layername)] = bias
+ # Copy just finest-grained ToRGB output layers. For example:
+ # ToRGB_lod0/weight -> output.conv.weight
+ i -= 1
+ resolution = 4 * (2 ** (i // 2))
+ tf_layername = 'ToRGB_lod0'
+ pt_layername = 'output_%dx%d' % (resolution, resolution)
+ result['%s.conv.weight' % pt_layername] = torch_from_tf(
+ params['%s/weight' % tf_layername]).permute(3, 2, 0, 1)
+ result['%s.wscale.b' % pt_layername] = torch_from_tf(
+ params['%s/bias' % tf_layername])
+ # Return parameters
+ return result
+
+def state_dict_from_old_pt_dict(params):
+ '''
+ Conversion from the old pytorch model layer names.
+ '''
+ result = {}
+ sizes = []
+ for i in itertools.count():
+ old_layername = 'features.%d' % i
+ pt_layername = 'layer%d' % (i + 1)
+ try:
+ weight = params['%s.conv.weight' % (old_layername)]
+ except KeyError:
+ break
+ if i == 0:
+ sizes.append(weight.shape[0])
+ if i % 2 == 0:
+ sizes.append(weight.shape[1])
+ result['%s.conv.weight' % (pt_layername)] = weight
+ result['%s.wscale.b' % (pt_layername)] = params[
+ '%s.wscale.b' % (old_layername)]
+ # Copy the output layers.
+ i -= 1
+ resolution = 4 * (2 ** (i // 2))
+ pt_layername = 'output_%dx%d' % (resolution, resolution)
+ result['%s.conv.weight' % pt_layername] = params['output.conv.weight']
+ result['%s.wscale.b' % pt_layername] = params['output.wscale.b']
+ # Return the converted parameters.
+ return result
+
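+if __name__ == '__main__':
+    # Illustrative sketch (editor's addition): instantiate one of the default
+    # architectures from the table above and run a random latent through it.
+    # For resolution=64 the input dimension (sizes[0]) is 512.
+    net = ProgressiveGenerator(resolution=64)
+    print_network(net)
+    z = torch.randn(2, 512)
+    with torch.no_grad():
+        img = net(z)
+    print(img.shape)    # expected: torch.Size([2, 3, 64, 64])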
diff --git a/netdissect/progress.py b/netdissect/progress.py
new file mode 100644
index 0000000000000000000000000000000000000000..702b24cf6668e6caad38d3c315eb658b6af4d230
--- /dev/null
+++ b/netdissect/progress.py
@@ -0,0 +1,98 @@
+'''
+Utilities for showing progress bars, controlling default verbosity, etc.
+'''
+
+# If the tqdm package is not available, then do not show progress bars;
+# just connect print_progress to print.
+try:
+ from tqdm import tqdm, tqdm_notebook
+except:
+ tqdm = None
+
+default_verbosity = False
+
+def verbose_progress(verbose):
+ '''
+ Sets default verbosity level. Set to True to see progress bars.
+ '''
+ global default_verbosity
+ default_verbosity = verbose
+
+def tqdm_terminal(it, *args, **kwargs):
+ '''
+ Some settings for tqdm that make it run better in resizable terminals.
+ '''
+ return tqdm(it, *args, dynamic_ncols=True, ascii=True,
+ leave=(not nested_tqdm()), **kwargs)
+
+def in_notebook():
+ '''
+ True if running inside a Jupyter notebook.
+ '''
+ # From https://stackoverflow.com/a/39662359/265298
+ try:
+ shell = get_ipython().__class__.__name__
+ if shell == 'ZMQInteractiveShell':
+ return True # Jupyter notebook or qtconsole
+ elif shell == 'TerminalInteractiveShell':
+ return False # Terminal running IPython
+ else:
+ return False # Other type (?)
+ except NameError:
+ return False # Probably standard Python interpreter
+
+def nested_tqdm():
+ '''
+ True if there is an active tqdm progress loop on the stack.
+ '''
+ return hasattr(tqdm, '_instances') and len(tqdm._instances) > 0
+
+def post_progress(**kwargs):
+ '''
+ When within a progress loop, post_progress(k=str) will display
+ the given k=str status on the right-hand-side of the progress
+ status bar. If not within a visible progress bar, does nothing.
+ '''
+ if nested_tqdm():
+ innermost = max(tqdm._instances, key=lambda x: x.pos)
+ innermost.set_postfix(**kwargs)
+
+def desc_progress(desc):
+ '''
+ When within a progress loop, desc_progress(str) changes the
+ left-hand-side description of the loop to the given description.
+ '''
+ if nested_tqdm():
+ innermost = max(tqdm._instances, key=lambda x: x.pos)
+ innermost.set_description(desc)
+
+def print_progress(*args):
+ '''
+ When within a progress loop, post_progress(k=str) will display
+ the given k=str status on the right-hand-side of the progress
+ status bar. If not within a visible progress bar, does nothing.
+ '''
+ if default_verbosity:
+ printfn = print if tqdm is None else tqdm.write
+ printfn(' '.join(str(s) for s in args))
+
+def default_progress(verbose=None, iftop=False):
+ '''
+ Returns a progress function that can wrap iterators to print
+ progress messages, if verbose is True.
+
+ If verbose is False or if iftop is True and there is already
+ a top-level tqdm loop being reported, then a quiet non-printing
+ identity function is returned.
+
+ verbose can also be set to a specific progress function rather
+ than True, and that function will be used.
+ '''
+ global default_verbosity
+ if verbose is None:
+ verbose = default_verbosity
+ if not verbose or (iftop and nested_tqdm()) or tqdm is None:
+ return lambda x, *args, **kw: x
+ if verbose == True:
+ return tqdm_notebook if in_notebook() else tqdm_terminal
+ return verbose
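+
+if __name__ == '__main__':
+    # Illustrative sketch (editor's addition): enable verbosity, wrap a loop
+    # with the default progress function, and post status from inside it.
+    # Falls back to a quiet identity wrapper when tqdm is not installed.
+    verbose_progress(True)
+    progress = default_progress()
+    total = 0
+    for i in progress(range(1000), desc='Counting'):
+        total += i
+        post_progress(total=total)
+    print_progress('Finished; total =', total)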
diff --git a/netdissect/runningstats.py b/netdissect/runningstats.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4093e0318edeecf8aebc34771adbde5043e2d4
--- /dev/null
+++ b/netdissect/runningstats.py
@@ -0,0 +1,773 @@
+'''
+Running statistics on the GPU using pytorch.
+
+RunningTopK maintains top-k statistics for a set of channels in parallel.
+RunningQuantile maintains (sampled) quantile statistics for a set of channels.
+'''
+
+import torch, math, numpy
+from collections import defaultdict
+
+class RunningTopK:
+ '''
+ A class to keep a running tally of the top k values (and indexes)
+ of any number of torch feature components. Will work on the GPU if
+ the data is on the GPU.
+
+ This version flattens all arrays to avoid crashes.
+ '''
+ def __init__(self, k=100, state=None):
+ if state is not None:
+ self.set_state_dict(state)
+ return
+ self.k = k
+ self.count = 0
+ # This version flattens all data internally to 2-d tensors,
+ # to avoid crashes with the current pytorch topk implementation.
+ # The data is puffed back out to arbitrary tensor shapes on output.
+ self.data_shape = None
+ self.top_data = None
+ self.top_index = None
+ self.next = 0
+ self.linear_index = 0
+ self.perm = None
+
+ def add(self, data):
+ '''
+ Adds a batch of data to be considered for the running top k.
+ The zeroth dimension enumerates the observations. All other
+ dimensions enumerate different features.
+ '''
+ if self.top_data is None:
+ # Allocation: allocate a buffer of size 5*k, at least 10, for each.
+ self.data_shape = data.shape[1:]
+ feature_size = int(numpy.prod(self.data_shape))
+ self.top_data = torch.zeros(
+ feature_size, max(10, self.k * 5), out=data.new())
+ self.top_index = self.top_data.clone().long()
+ self.linear_index = 0 if len(data.shape) == 1 else torch.arange(
+ feature_size, out=self.top_index.new()).mul_(
+ self.top_data.shape[-1])[:,None]
+ size = data.shape[0]
+ sk = min(size, self.k)
+ if self.top_data.shape[-1] < self.next + sk:
+ # Compression: if full, keep topk only.
+ self.top_data[:,:self.k], self.top_index[:,:self.k] = (
+ self.result(sorted=False, flat=True))
+ self.next = self.k
+ free = self.top_data.shape[-1] - self.next
+ # Pick: copy the top sk of the next batch into the buffer.
+ # Currently strided topk is slow. So we clone after transpose.
+ # TODO: remove the clone() if it becomes faster.
+ cdata = data.contiguous().view(size, -1).t().clone()
+ td, ti = cdata.topk(sk, sorted=False)
+ self.top_data[:,self.next:self.next+sk] = td
+ self.top_index[:,self.next:self.next+sk] = (ti + self.count)
+ self.next += sk
+ self.count += size
+
+ def result(self, sorted=True, flat=False):
+ '''
+ Returns top k data items and indexes in each dimension,
+ with channels in the first dimension and k in the last dimension.
+ '''
+ k = min(self.k, self.next)
+ # bti are top indexes relative to buffer array.
+ td, bti = self.top_data[:,:self.next].topk(k, sorted=sorted)
+ # we want to report top indexes globally, which is ti.
+ ti = self.top_index.view(-1)[
+ (bti + self.linear_index).view(-1)
+ ].view(*bti.shape)
+ if flat:
+ return td, ti
+ else:
+ return (td.view(*(self.data_shape + (-1,))),
+ ti.view(*(self.data_shape + (-1,))))
+
+ def to_(self, device):
+ self.top_data = self.top_data.to(device)
+ self.top_index = self.top_index.to(device)
+ if isinstance(self.linear_index, torch.Tensor):
+ self.linear_index = self.linear_index.to(device)
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + '.' +
+ self.__class__.__name__ + '()',
+ k=self.k,
+ count=self.count,
+ data_shape=tuple(self.data_shape),
+ top_data=self.top_data.cpu().numpy(),
+ top_index=self.top_index.cpu().numpy(),
+ next=self.next,
+ linear_index=(self.linear_index.cpu().numpy()
+ if isinstance(self.linear_index, torch.Tensor)
+ else self.linear_index),
+ perm=self.perm)
+
+ def set_state_dict(self, dic):
+ self.k = dic['k'].item()
+ self.count = dic['count'].item()
+ self.data_shape = tuple(dic['data_shape'])
+ self.top_data = torch.from_numpy(dic['top_data'])
+ self.top_index = torch.from_numpy(dic['top_index'])
+ self.next = dic['next'].item()
+ self.linear_index = (torch.from_numpy(dic['linear_index'])
+ if len(dic['linear_index'].shape) > 0
+ else dic['linear_index'].item())
+
+class RunningQuantile:
+ """
+ Streaming randomized quantile computation for torch.
+
+ Add any amount of data repeatedly via add(data). At any time,
+ quantile estimates (or old-style percentiles) can be read out using
+ quantiles(q) or percentiles(p).
+
+ Accuracy scales according to resolution: the default is to
+ set resolution to be accurate to better than 0.1%,
+ while limiting storage to about 50,000 samples.
+
+ Good for computing quantiles of huge data without using much memory.
+ Works well on arbitrary data with probability near 1.
+
+ Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty
+ from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf
+ """
+
+ def __init__(self, resolution=6 * 1024, buffersize=None, seed=None,
+ state=None):
+ if state is not None:
+ self.set_state_dict(state)
+ return
+ self.depth = None
+ self.dtype = None
+ self.device = None
+ self.resolution = resolution
+ # Default buffersize: 128 samples (and smaller than resolution).
+ if buffersize is None:
+ buffersize = min(128, (resolution + 7) // 8)
+ self.buffersize = buffersize
+ self.samplerate = 1.0
+ self.data = None
+ self.firstfree = [0]
+ self.randbits = torch.ByteTensor(resolution)
+ self.currentbit = len(self.randbits) - 1
+ self.extremes = None
+ self.size = 0
+
+ def _lazy_init(self, incoming):
+ self.depth = incoming.shape[1]
+ self.dtype = incoming.dtype
+ self.device = incoming.device
+ self.data = [torch.zeros(self.depth, self.resolution,
+ dtype=self.dtype, device=self.device)]
+ self.extremes = torch.zeros(self.depth, 2,
+ dtype=self.dtype, device=self.device)
+ self.extremes[:,0] = float('inf')
+ self.extremes[:,-1] = -float('inf')
+
+ def to_(self, device):
+ """Switches internal storage to specified device."""
+ if device != self.device:
+ old_data = self.data
+ old_extremes = self.extremes
+ self.data = [d.to(device) for d in self.data]
+ self.extremes = self.extremes.to(device)
+ self.device = self.extremes.device
+ del old_data
+ del old_extremes
+
+ def add(self, incoming):
+ if self.depth is None:
+ self._lazy_init(incoming)
+ assert len(incoming.shape) == 2
+ assert incoming.shape[1] == self.depth, (incoming.shape[1], self.depth)
+ self.size += incoming.shape[0]
+ # Convert to a flat torch array.
+ if self.samplerate >= 1.0:
+ self._add_every(incoming)
+ return
+ # If we are sampling, then subsample a large chunk at a time.
+ self._scan_extremes(incoming)
+ chunksize = int(math.ceil(self.buffersize / self.samplerate))
+ for index in range(0, len(incoming), chunksize):
+ batch = incoming[index:index+chunksize]
+ sample = sample_portion(batch, self.samplerate)
+ if len(sample):
+ self._add_every(sample)
+
+ def _add_every(self, incoming):
+ supplied = len(incoming)
+ index = 0
+ while index < supplied:
+ ff = self.firstfree[0]
+ available = self.data[0].shape[1] - ff
+ if available == 0:
+ if not self._shift():
+ # If we shifted by subsampling, then subsample.
+ incoming = incoming[index:]
+ if self.samplerate >= 0.5:
+ # First time sampling - the data source is very large.
+ self._scan_extremes(incoming)
+ incoming = sample_portion(incoming, self.samplerate)
+ index = 0
+ supplied = len(incoming)
+ ff = self.firstfree[0]
+ available = self.data[0].shape[1] - ff
+ copycount = min(available, supplied - index)
+ self.data[0][:,ff:ff + copycount] = torch.t(
+ incoming[index:index + copycount,:])
+ self.firstfree[0] += copycount
+ index += copycount
+
+ def _shift(self):
+ index = 0
+ # If remaining space at the current layer is less than half prev
+ # buffer size (rounding up), then we need to shift it up to ensure
+ # enough space for future shifting.
+ while self.data[index].shape[1] - self.firstfree[index] < (
+ -(-self.data[index-1].shape[1] // 2) if index else 1):
+ if index + 1 >= len(self.data):
+ return self._expand()
+ data = self.data[index][:,0:self.firstfree[index]]
+ data = data.sort()[0]
+ if index == 0 and self.samplerate >= 1.0:
+ self._update_extremes(data[:,0], data[:,-1])
+ offset = self._randbit()
+ position = self.firstfree[index + 1]
+ subset = data[:,offset::2]
+ self.data[index + 1][:,position:position + subset.shape[1]] = subset
+ self.firstfree[index] = 0
+ self.firstfree[index + 1] += subset.shape[1]
+ index += 1
+ return True
+
+ def _scan_extremes(self, incoming):
+ # When sampling, we need to scan every item still to get extremes
+ self._update_extremes(
+ torch.min(incoming, dim=0)[0],
+ torch.max(incoming, dim=0)[0])
+
+ def _update_extremes(self, minr, maxr):
+ self.extremes[:,0] = torch.min(
+ torch.stack([self.extremes[:,0], minr]), dim=0)[0]
+ self.extremes[:,-1] = torch.max(
+ torch.stack([self.extremes[:,-1], maxr]), dim=0)[0]
+
+ def _randbit(self):
+ self.currentbit += 1
+ if self.currentbit >= len(self.randbits):
+ self.randbits.random_(to=2)
+ self.currentbit = 0
+ return self.randbits[self.currentbit]
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + '.' +
+ self.__class__.__name__ + '()',
+ resolution=self.resolution,
+ depth=self.depth,
+ buffersize=self.buffersize,
+ samplerate=self.samplerate,
+ data=[d.cpu().numpy()[:,:f].T
+ for d, f in zip(self.data, self.firstfree)],
+ sizes=[d.shape[1] for d in self.data],
+ extremes=self.extremes.cpu().numpy(),
+ size=self.size)
+
+ def set_state_dict(self, dic):
+ self.resolution = int(dic['resolution'])
+ self.randbits = torch.ByteTensor(self.resolution)
+ self.currentbit = len(self.randbits) - 1
+ self.depth = int(dic['depth'])
+ self.buffersize = int(dic['buffersize'])
+ self.samplerate = float(dic['samplerate'])
+ firstfree = []
+ buffers = []
+ for d, s in zip(dic['data'], dic['sizes']):
+ firstfree.append(d.shape[0])
+ buf = numpy.zeros((d.shape[1], s), dtype=d.dtype)
+ buf[:,:d.shape[0]] = d.T
+ buffers.append(torch.from_numpy(buf))
+ self.firstfree = firstfree
+ self.data = buffers
+ self.extremes = torch.from_numpy((dic['extremes']))
+ self.size = int(dic['size'])
+ self.dtype = self.extremes.dtype
+ self.device = self.extremes.device
+
+ def minmax(self):
+ if self.firstfree[0]:
+ self._scan_extremes(self.data[0][:,:self.firstfree[0]].t())
+ return self.extremes.clone()
+
+ def median(self):
+ return self.quantiles([0.5])[:,0]
+
+ def mean(self):
+ return self.integrate(lambda x: x) / self.size
+
+ def variance(self):
+ mean = self.mean()[:,None]
+ return self.integrate(lambda x: (x - mean).pow(2)) / (self.size - 1)
+
+ def stdev(self):
+ return self.variance().sqrt()
+
+ def _expand(self):
+ cap = self._next_capacity()
+ if cap > 0:
+ # First, make a new layer of the proper capacity.
+ self.data.insert(0, torch.zeros(self.depth, cap,
+ dtype=self.dtype, device=self.device))
+ self.firstfree.insert(0, 0)
+ else:
+ # Unless we're so big we are just subsampling.
+ assert self.firstfree[0] == 0
+ self.samplerate *= 0.5
+ for index in range(1, len(self.data)):
+ # Scan for existing data that needs to be moved down a level.
+ amount = self.firstfree[index]
+ if amount == 0:
+ continue
+ position = self.firstfree[index-1]
+ # Move data down if it would leave enough empty space there
+ # This is the key invariant: enough empty space to fit half
+ # of the previous level's buffer size (rounding up)
+ if self.data[index-1].shape[1] - (amount + position) >= (
+ -(-self.data[index-2].shape[1] // 2) if (index-1) else 1):
+ self.data[index-1][:,position:position + amount] = (
+ self.data[index][:,:amount])
+ self.firstfree[index-1] += amount
+ self.firstfree[index] = 0
+ else:
+ # Scrunch the data if it would not.
+ data = self.data[index][:,:amount]
+ data = data.sort()[0]
+ if index == 1:
+ self._update_extremes(data[:,0], data[:,-1])
+ offset = self._randbit()
+ scrunched = data[:,offset::2]
+ self.data[index][:,:scrunched.shape[1]] = scrunched
+ self.firstfree[index] = scrunched.shape[1]
+ return cap > 0
+
+ def _next_capacity(self):
+ cap = int(math.ceil(self.resolution * (0.67 ** len(self.data))))
+ if cap < 2:
+ return 0
+ # Round up to the nearest multiple of 8 for better GPU alignment.
+ cap = -8 * (-cap // 8)
+ return max(self.buffersize, cap)
+
+ def _weighted_summary(self, sort=True):
+ if self.firstfree[0]:
+ self._scan_extremes(self.data[0][:,:self.firstfree[0]].t())
+ size = sum(self.firstfree) + 2
+ weights = torch.FloatTensor(size) # Floating point
+ summary = torch.zeros(self.depth, size,
+ dtype=self.dtype, device=self.device)
+ weights[0:2] = 0
+ summary[:,0:2] = self.extremes
+ index = 2
+ for level, ff in enumerate(self.firstfree):
+ if ff == 0:
+ continue
+ summary[:,index:index + ff] = self.data[level][:,:ff]
+ weights[index:index + ff] = 2.0 ** level
+ index += ff
+ assert index == summary.shape[1]
+ if sort:
+ summary, order = torch.sort(summary, dim=-1)
+ weights = weights[order.view(-1).cpu()].view(order.shape)
+ return (summary, weights)
+
+ def quantiles(self, quantiles, old_style=False):
+ if self.size == 0:
+ return torch.full((self.depth, len(quantiles)), torch.nan)
+ summary, weights = self._weighted_summary()
+ cumweights = torch.cumsum(weights, dim=-1) - weights / 2
+ if old_style:
+ # To be compatible with numpy.percentile.
+ cumweights -= cumweights[:,0:1].clone()
+ cumweights /= cumweights[:,-1:].clone()
+ else:
+ cumweights /= torch.sum(weights, dim=-1, keepdim=True)
+ result = torch.zeros(self.depth, len(quantiles),
+ dtype=self.dtype, device=self.device)
+ # numpy is needed for interpolation
+ if not hasattr(quantiles, 'cpu'):
+ quantiles = torch.Tensor(quantiles)
+ nq = quantiles.cpu().numpy()
+ ncw = cumweights.cpu().numpy()
+ nsm = summary.cpu().numpy()
+ for d in range(self.depth):
+ result[d] = torch.tensor(numpy.interp(nq, ncw[d], nsm[d]),
+ dtype=self.dtype, device=self.device)
+ return result
+
+ def integrate(self, fun):
+ result = None
+ for level, ff in enumerate(self.firstfree):
+ if ff == 0:
+ continue
+ term = torch.sum(
+ fun(self.data[level][:,:ff]) * (2.0 ** level),
+ dim=-1)
+ if result is None:
+ result = term
+ else:
+ result += term
+ if result is not None:
+ result /= self.samplerate
+ return result
+
+ def percentiles(self, percentiles):
+ return self.quantiles(percentiles, old_style=True)
+
+ def readout(self, count=1001, old_style=True):
+ return self.quantiles(
+ torch.linspace(0.0, 1.0, count), old_style=old_style)
+
+ def normalize(self, data):
+ '''
+ Given input data as taken from the training distribution,
+ normalizes every channel to reflect quantile values,
+ uniformly distributed, within [0, 1].
+ '''
+ assert self.size > 0
+ assert data.shape[0] == self.depth
+ summary, weights = self._weighted_summary()
+ cumweights = torch.cumsum(weights, dim=-1) - weights / 2
+ cumweights /= torch.sum(weights, dim=-1, keepdim=True)
+ result = torch.zeros_like(data).float()
+ # numpy is needed for interpolation
+ ndata = data.cpu().numpy().reshape((data.shape[0], -1))
+ ncw = cumweights.cpu().numpy()
+ nsm = summary.cpu().numpy()
+ for d in range(self.depth):
+ normed = torch.tensor(numpy.interp(ndata[d], nsm[d], ncw[d]),
+ dtype=torch.float, device=data.device).clamp_(0.0, 1.0)
+ if len(data.shape) > 1:
+ normed = normed.view(*(data.shape[1:]))
+ result[d] = normed
+ return result
+
+
+class RunningConditionalQuantile:
+ '''
+ Equivalent to a map from conditions (any python hashable type)
+ to RunningQuantiles. The reason for this class is to allow limited
+ GPU memory to be exploited while counting quantile stats on many
+ different conditions, a few of which are common and which benefit
+ from GPU, but most of which are rare and would not all fit into
+ GPU RAM.
+
+ To move a set of conditions to a device, use rcq.to_(device, conds).
+ Then in the future, move the tallied data to the device before
+ calling rcq.add, that is, rcq.add(cond, data.to(device)).
+
+ To allow the caller to decide which conditions to allow to use GPU,
+ rcq.most_common_conditions(n) returns a list of the n most commonly
+ added conditions so far.
+ '''
+ def __init__(self, resolution=6 * 1024, buffersize=None, seed=None,
+ state=None):
+ self.first_rq = None
+ self.call_stats = defaultdict(int)
+ self.running_quantiles = {}
+ if state is not None:
+ self.set_state_dict(state)
+ return
+ self.rq_args = dict(resolution=resolution, buffersize=buffersize,
+ seed=seed)
+
+ def add(self, condition, incoming):
+ if condition not in self.running_quantiles:
+ self.running_quantiles[condition] = RunningQuantile(**self.rq_args)
+ if self.first_rq is None:
+ self.first_rq = self.running_quantiles[condition]
+ self.call_stats[condition] += 1
+ rq = self.running_quantiles[condition]
+ # For performance reasons, the caller can move some conditions to
+ # the CPU if they are not among the most common conditions.
+ if rq.device is not None and (rq.device != incoming.device):
+ rq.to_(incoming.device)
+ self.running_quantiles[condition].add(incoming)
+
+ def most_common_conditions(self, n):
+ return sorted(self.call_stats.keys(),
+ key=lambda c: -self.call_stats[c])[:n]
+
+ def collected_add(self, conditions, incoming):
+ for c in conditions:
+ self.add(c, incoming)
+
+ def conditional(self, c):
+ return self.running_quantiles[c]
+
+ def collected_quantiles(self, conditions, quantiles, old_style=False):
+ result = torch.zeros(
+ size=(len(conditions), self.first_rq.depth, len(quantiles)),
+ dtype=self.first_rq.dtype,
+ device=self.first_rq.device)
+ for i, c in enumerate(conditions):
+ if c in self.running_quantiles:
+ result[i] = self.running_quantiles[c].quantiles(
+ quantiles, old_style)
+ return result
+
+ def collected_normalize(self, conditions, values):
+ result = torch.zeros(
+ size=(len(conditions), values.shape[0], values.shape[1]),
+ dtype=torch.float,
+ device=self.first_rq.device)
+ for i, c in enumerate(conditions):
+ if c in self.running_quantiles:
+ result[i] = self.running_quantiles[c].normalize(values)
+ return result
+
+ def to_(self, device, conditions=None):
+ if conditions is None:
+ conditions = self.running_quantiles.keys()
+ for cond in conditions:
+ if cond in self.running_quantiles:
+ self.running_quantiles[cond].to_(device)
+
+ def state_dict(self):
+ conditions = sorted(self.running_quantiles.keys())
+ result = dict(
+ constructor=self.__module__ + '.' +
+ self.__class__.__name__ + '()',
+ rq_args=self.rq_args,
+ conditions=conditions)
+ for i, c in enumerate(conditions):
+ result.update({
+ '%d.%s' % (i, k): v
+ for k, v in self.running_quantiles[c].state_dict().items()})
+ return result
+
+ def set_state_dict(self, dic):
+ self.rq_args = dic['rq_args'].item()
+ conditions = list(dic['conditions'])
+ subdicts = defaultdict(dict)
+ for k, v in dic.items():
+ if '.' in k:
+ p, s = k.split('.', 1)
+ subdicts[p][s] = v
+ self.running_quantiles = {
+ c: RunningQuantile(state=subdicts[str(i)])
+ for i, c in enumerate(conditions)}
+ if conditions:
+ self.first_rq = self.running_quantiles[conditions[0]]
+
+ # example usage:
+ # levels = rqc.conditional(()).quantiles(1 - fracs)
+ # denoms = 1 - rqc.collected_normalize(cats, levels)
+ # isects = 1 - rqc.collected_normalize(labels, levels)
+ # unions = fracs + denoms[cats] - isects
+ # iou = isects / unions
+
+
+
+
+class RunningCrossCovariance:
+ '''
+ Running computation. Use this when an off-diagonal block of the
+ covariance matrix is needed (e.g., when the whole covariance matrix
+ does not fit in the GPU).
+
+ Uses the Chan-style numerically stable update of the mean and covariance.
+ Chan, Golub, and LeVeque, 1983. http://www.jstor.org/stable/2683386
+ '''
+ def __init__(self, state=None):
+ if state is not None:
+ self.set_state_dict(state)
+ return
+ self.count = 0
+ self._mean = None
+ self.cmom2 = None
+ self.v_cmom2 = None
+
+ def add(self, a, b):
+ if len(a.shape) == 1:
+ a = a[None, :]
+ b = b[None, :]
+ assert(a.shape[0] == b.shape[0])
+ if len(a.shape) > 2:
+ a, b = [d.view(d.shape[0], d.shape[1], -1).permute(0, 2, 1
+ ).contiguous().view(-1, d.shape[1]) for d in [a, b]]
+ batch_count = a.shape[0]
+ batch_mean = [d.sum(0) / batch_count for d in [a, b]]
+ centered = [d - bm for d, bm in zip([a, b], batch_mean)]
+ # If more than 10 billion operations, divide into batches.
+ sub_batch = -(-(10 << 30) // (a.shape[1] * b.shape[1]))
+ # Initial batch.
+ if self._mean is None:
+ self.count = batch_count
+ self._mean = batch_mean
+ self.v_cmom2 = [c.pow(2).sum(0) for c in centered]
+ self.cmom2 = a.new(a.shape[1], b.shape[1]).zero_()
+ progress_addbmm(self.cmom2, centered[0][:,:,None],
+ centered[1][:,None,:], sub_batch)
+ return
+ # Update a batch using Chan-style update for numerical stability.
+ oldcount = self.count
+ self.count += batch_count
+ new_frac = float(batch_count) / self.count
+ # Update the mean according to the batch deviation from the old mean.
+ delta = [bm.sub_(m).mul_(new_frac)
+ for bm, m in zip(batch_mean, self._mean)]
+ for m, d in zip(self._mean, delta):
+ m.add_(d)
+ # Update the cross-covariance using the batch deviation
+ progress_addbmm(self.cmom2, centered[0][:,:,None],
+ centered[1][:,None,:], sub_batch)
+ self.cmom2.addmm_(alpha=new_frac * oldcount,
+ mat1=delta[0][:,None], mat2=delta[1][None,:])
+ # Update the variance using the batch deviation
+ for c, vc2, d in zip(centered, self.v_cmom2, delta):
+ vc2.add_(c.pow(2).sum(0))
+ vc2.add_(d.pow_(2).mul_(new_frac * oldcount))
+
+ def mean(self):
+ return self._mean
+
+ def variance(self):
+ return [vc2 / (self.count - 1) for vc2 in self.v_cmom2]
+
+ def stdev(self):
+ return [v.sqrt() for v in self.variance()]
+
+ def covariance(self):
+ return self.cmom2 / (self.count - 1)
+
+ def correlation(self):
+ covariance = self.covariance()
+ rstdev = [s.reciprocal() for s in self.stdev()]
+ cor = rstdev[0][:,None] * covariance * rstdev[1][None,:]
+ # Remove NaNs
+ cor[torch.isnan(cor)] = 0
+ return cor
+
+ def to_(self, device):
+ self._mean = [m.to(device) for m in self._mean]
+ self.v_cmom2 = [vcs.to(device) for vcs in self.v_cmom2]
+ self.cmom2 = self.cmom2.to(device)
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + '.' +
+ self.__class__.__name__ + '()',
+ count=self.count,
+ mean_a=self._mean[0].cpu().numpy(),
+ mean_b=self._mean[1].cpu().numpy(),
+ cmom2_a=self.v_cmom2[0].cpu().numpy(),
+ cmom2_b=self.v_cmom2[1].cpu().numpy(),
+ cmom2=self.cmom2.cpu().numpy())
+
+ def set_state_dict(self, dic):
+ self.count = dic['count'].item()
+ self._mean = [torch.from_numpy(dic[k]) for k in ['mean_a', 'mean_b']]
+ self.v_cmom2 = [torch.from_numpy(dic[k])
+ for k in ['cmom2_a', 'cmom2_b']]
+ self.cmom2 = torch.from_numpy(dic['cmom2'])
+
+def progress_addbmm(accum, x, y, batch_size):
+ '''
+ Break up very large addbmm operations into batches so progress can be seen.
+ '''
+ from .progress import default_progress
+ if x.shape[0] <= batch_size:
+ return accum.addbmm_(x, y)
+ progress = default_progress(None)
+ for i in progress(range(0, x.shape[0], batch_size), desc='bmm'):
+ accum.addbmm_(x[i:i+batch_size], y[i:i+batch_size])
+ return accum
+
+
+def sample_portion(vec, p=0.5):
+ # Keep each row independently with probability p; a boolean mask
+ # avoids the deprecated uint8 mask indexing.
+ bits = torch.rand(vec.shape[0], device=vec.device) < p
+ return vec[bits]
+
+if __name__ == '__main__':
+ import warnings
+ warnings.filterwarnings("error")
+ import time
+ import argparse
+ parser = argparse.ArgumentParser(
+ description='Test things out')
+ parser.add_argument('--mode', default='cpu', help='cpu or cuda')
+ parser.add_argument('--test_size', type=int, default=1000000)
+ args = parser.parse_args()
+
+ # An adversarial case: we keep finding more numbers in the middle
+ # as the stream goes on.
+ amount = args.test_size
+ quantiles = 1000
+ data = numpy.arange(float(amount))
+ data[1::2] = data[-1::-2] + (len(data) - 1)
+ data /= 2
+ depth = 50
+ test_cuda = torch.cuda.is_available()
+ alldata = data[:,None] + (numpy.arange(depth) * amount)[None, :]
+ actual_sum = torch.FloatTensor(numpy.sum(alldata * alldata, axis=0))
+ amt = amount // depth
+ for r in range(depth):
+ numpy.random.shuffle(alldata[r*amt:r*amt+amt,r])
+ if args.mode == 'cuda':
+ alldata = torch.cuda.FloatTensor(alldata)
+ dtype = torch.float
+ device = torch.device('cuda')
+ else:
+ alldata = torch.FloatTensor(alldata)
+ dtype = torch.float
+ device = None
+ starttime = time.time()
+ qc = RunningQuantile(resolution=6 * 1024)
+ qc.add(alldata)
+ # Test state dict
+ saved = qc.state_dict()
+ # numpy.savez('foo.npz', **saved)
+ # saved = numpy.load('foo.npz')
+ qc = RunningQuantile(state=saved)
+ assert not qc.device.type == 'cuda'
+ qc.add(alldata)
+ actual_sum *= 2
+ ro = qc.readout(1001).cpu()
+ endtime = time.time()
+ gt = torch.linspace(0, amount, quantiles+1)[None,:] + (
+ torch.arange(qc.depth, dtype=torch.float) * amount)[:,None]
+ maxreldev = torch.max(torch.abs(ro - gt) / amount) * quantiles
+ print("Maximum relative deviation among %d perentiles: %f" % (
+ quantiles, maxreldev))
+ minerr = torch.max(torch.abs(qc.minmax().cpu()[:,0] -
+ torch.arange(qc.depth, dtype=torch.float) * amount))
+ maxerr = torch.max(torch.abs((qc.minmax().cpu()[:, -1] + 1) -
+ (torch.arange(qc.depth, dtype=torch.float) + 1) * amount))
+ print("Minmax error %f, %f" % (minerr, maxerr))
+ interr = torch.max(torch.abs(qc.integrate(lambda x: x * x).cpu()
+ - actual_sum) / actual_sum)
+ print("Integral error: %f" % interr)
+ medianerr = torch.max(torch.abs(qc.median() -
+ alldata.median(0)[0]) / alldata.median(0)[0]).cpu()
+ print("Median error: %f" % interr)
+ meanerr = torch.max(
+ torch.abs(qc.mean() - alldata.mean(0)) / alldata.mean(0)).cpu()
+ print("Mean error: %f" % meanerr)
+ varerr = torch.max(
+ torch.abs(qc.variance() - alldata.var(0)) / alldata.var(0)).cpu()
+ print("Variance error: %f" % varerr)
+ counterr = ((qc.integrate(lambda x: torch.ones(x.shape[-1]).cpu())
+ - qc.size) / (0.0 + qc.size)).item()
+ print("Count error: %f" % counterr)
+ print("Time %f" % (endtime - starttime))
+ # Algorithm is randomized, so some of these will fail with low probability.
+ assert maxreldev < 1.0
+ assert minerr == 0.0
+ assert maxerr == 0.0
+ assert interr < 0.01
+ assert abs(counterr) < 0.001
+ print("OK")
diff --git a/netdissect/sampler.py b/netdissect/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f1b46da117403c7f6ddcc1877bd9d70ded962b
--- /dev/null
+++ b/netdissect/sampler.py
@@ -0,0 +1,134 @@
+'''
+A sampler is just a list of integers listing the indexes of the
+inputs in a data set to sample. For reproducibility, the
+FixedRandomSubsetSampler uses a seeded prng to produce the same
+sequence always. FixedSubsetSampler is just a wrapper for an
+explicit list of integers.
+
+coordinate_sample solves another sampling problem: when testing
+convolutional outputs, we can reduce the data explosion by sampling
+random points of the feature map rather than the entire feature map.
+coordinate_sample does this in a deterministic way that is also
+resolution-independent.
+'''
+
+import numpy
+import random
+from torch.utils.data.sampler import Sampler
+
+class FixedSubsetSampler(Sampler):
+ """Represents a fixed sequence of data set indices.
+ Subsets can be created by specifying a subset of output indexes.
+ """
+ def __init__(self, samples):
+ self.samples = samples
+
+ def __iter__(self):
+ return iter(self.samples)
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, key):
+ return self.samples[key]
+
+ def subset(self, new_subset):
+ return FixedSubsetSampler(self.dereference(new_subset))
+
+ def dereference(self, indices):
+ '''
+ Translate output sample indices (small numbers indexing the sample)
+ to input sample indices (larger number indexing the original full set)
+ '''
+ return [self.samples[i] for i in indices]
+
+
+class FixedRandomSubsetSampler(FixedSubsetSampler):
+ """Samples a fixed number of samples from the dataset, deterministically.
+ Arguments:
+ data_source,
+ start (optional), end (optional),
+ seed (optional)
+ """
+ def __init__(self, data_source, start=None, end=None, seed=1):
+ rng = random.Random(seed)
+ shuffled = list(range(len(data_source)))
+ rng.shuffle(shuffled)
+ self.data_source = data_source
+ super(FixedRandomSubsetSampler, self).__init__(shuffled[start:end])
+
+ def class_subset(self, class_filter):
+ '''
+ Returns only the subset matching the given rule.
+ '''
+ if isinstance(class_filter, int):
+ rule = lambda d: d[1] == class_filter
+ else:
+ rule = class_filter
+ return self.subset([i for i, j in enumerate(self.samples)
+ if rule(self.data_source[j])])
+
+def coordinate_sample(shape, sample_size, seeds, grid=13, seed=1, flat=False):
+ '''
+ Returns len(seeds) sets of sample_size grid points within
+ the given shape. If the shape dimensions are a multiple of 'grid',
+ then sampled points within the same row will never be duplicated.
+ '''
+ if flat:
+ sampind = numpy.zeros((len(seeds), sample_size), dtype=int)
+ else:
+ sampind = numpy.zeros((len(seeds), 2, sample_size), dtype=int)
+ assert sample_size <= grid
+ for j, seed in enumerate(seeds):
+ rng = numpy.random.RandomState(seed)
+ # Shuffle the 169 random grid squares, and pick :sample_size.
+ square_count = grid ** len(shape)
+ square = numpy.stack(numpy.unravel_index(
+ rng.choice(square_count, square_count)[:sample_size],
+ (grid,) * len(shape)))
+ # Then add a random offset to each x, y and put in the range [0...1)
+ # Notice this selects the same locations regardless of resolution.
+ uniform = (square + rng.uniform(size=square.shape)) / grid
+ # TODO: support affine scaling so that we can align receptive field
+ # centers exactly when sampling neurons in different layers.
+ coords = (uniform * numpy.array(shape)[:,None]).astype(int)
+ # Now take sample_size without replacement. We do this in a way
+ # such that if sample_size is decreased or increased up to 'grid',
+ # the selected points become a subset, not totally different points.
+ if flat:
+ sampind[j] = numpy.ravel_multi_index(coords, dims=shape)
+ else:
+ sampind[j] = coords
+ return sampind
+
+if __name__ == '__main__':
+ from numpy.testing import assert_almost_equal
+ # Test that coordinate_sample is deterministic, in-range, and scalable.
+ assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102)),
+ [[[14, 0, 12, 11, 8, 13, 11, 20, 7, 20],
+ [ 9, 22, 7, 11, 23, 18, 21, 15, 2, 5]]])
+ assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 102)),
+ [[[ 7, 0, 6, 5, 4, 6, 5, 10, 3, 20 // 2],
+ [ 4, 11, 3, 5, 11, 9, 10, 7, 1, 5 // 2]]])
+ assert_almost_equal(coordinate_sample((13, 13), 10, range(100, 102),
+ flat=True),
+ [[ 8, 24, 67, 103, 87, 79, 138, 94, 98, 53],
+ [ 95, 11, 81, 70, 63, 87, 75, 137, 40, 2+10*13]])
+ assert_almost_equal(coordinate_sample((13, 13), 10, range(101, 103),
+ flat=True),
+ [[ 95, 11, 81, 70, 63, 87, 75, 137, 40, 132],
+ [ 0, 78, 114, 111, 66, 45, 72, 73, 79, 135]])
+ assert_almost_equal(coordinate_sample((26, 26), 10, range(101, 102),
+ flat=True),
+ [[373, 22, 319, 297, 231, 356, 307, 535, 184, 5+20*26]])
+ # Test FixedRandomSubsetSampler
+ fss = FixedRandomSubsetSampler(range(10))
+ assert len(fss) == 10
+ assert_almost_equal(list(fss), [8, 0, 3, 4, 5, 2, 9, 6, 7, 1])
+ fss = FixedRandomSubsetSampler(range(10), 3, 8)
+ assert len(fss) == 5
+ assert_almost_equal(list(fss), [4, 5, 2, 9, 6])
+ fss = FixedRandomSubsetSampler(
+ [(i, i % 3) for i in range(10)]).class_subset(class_filter=1)
+ assert len(fss) == 3
+ assert_almost_equal(list(fss), [4, 7, 1])
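+    # Additional sketch (editor's addition): subset() composes with
+    # dereference(), so a nested subset still indexes the original data set.
+    sub = fss.subset([0, 2])
+    assert_almost_equal(list(sub), [4, 1])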
diff --git a/netdissect/segdata.py b/netdissect/segdata.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3cb6dfac8985d9c55344abbc26cc26c4862aa85
--- /dev/null
+++ b/netdissect/segdata.py
@@ -0,0 +1,74 @@
+import os, numpy, torch, json
+from .parallelfolder import ParallelImageFolders
+from torchvision import transforms
+from torchvision.transforms.functional import to_tensor, normalize
+
+class FieldDef(object):
+ def __init__(self, field, index, bitshift, bitmask, labels):
+ self.field = field
+ self.index = index
+ self.bitshift = bitshift
+ self.bitmask = bitmask
+ self.labels = labels
+
+class MultiSegmentDataset(object):
+ '''
+ Just like ClevrMulticlassDataset, but the second stream is a one-hot
+ segmentation tensor rather than a flat one-hot presence vector.
+
+ MultiSegmentDataset('dataset/clevrseg',
+ imgdir='images/train/positive',
+ segdir='images/train/segmentation')
+ '''
+ def __init__(self, directory, transform=None,
+ imgdir='img', segdir='seg', val=False, size=None):
+ self.segdataset = ParallelImageFolders(
+ [os.path.join(directory, imgdir),
+ os.path.join(directory, segdir)],
+ transform=transform)
+ self.fields = []
+ with open(os.path.join(directory, 'labelnames.json'), 'r') as f:
+ for defn in json.load(f):
+ self.fields.append(FieldDef(
+ defn['field'], defn['index'], defn['bitshift'],
+ defn['bitmask'], defn['label']))
+ self.labels = ['-'] # Reserve label 0 to mean "no label"
+ self.categories = []
+ self.label_category = [0]
+ for fieldnum, f in enumerate(self.fields):
+ self.categories.append(f.field)
+ f.firstchannel = len(self.labels)
+ f.channels = len(f.labels) - 1
+ for lab in f.labels[1:]:
+ self.labels.append(lab)
+ self.label_category.append(fieldnum)
+ # Reserve 25% of the dataset for validation.
+ first_val = int(len(self.segdataset) * 0.75)
+ self.val = val
+ self.first = first_val if val else 0
+ self.length = len(self.segdataset) - first_val if val else first_val
+ # Truncate the dataset if requested.
+ if size:
+ self.length = min(size, self.length)
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, index):
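+        # Returns (img, segout, bincount): the transformed image, an int64
+        # label map of shape (num_categories, H, W), and a histogram of label
+        # counts over all pixels (length len(self.labels)).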
+ img, segimg = self.segdataset[index + self.first]
+ segin = numpy.array(segimg, numpy.uint8, copy=False)
+ segout = torch.zeros(len(self.categories),
+ segin.shape[0], segin.shape[1], dtype=torch.int64)
+ for i, field in enumerate(self.fields):
+ fielddata = ((torch.from_numpy(segin[:, :, field.index])
+ >> field.bitshift) & field.bitmask)
+ segout[i] = field.firstchannel + fielddata - 1
+ bincount = numpy.bincount(segout.flatten(),
+ minlength=len(self.labels))
+ return img, segout, bincount
+
+if __name__ == '__main__':
+ ds = MultiSegmentDataset('dataset/clevrseg')
+ print(ds[0])
+ import pdb; pdb.set_trace()
+
diff --git a/netdissect/segmenter.py b/netdissect/segmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5ebe364bc30f32581f0d560e11f08bfbd0d1731
--- /dev/null
+++ b/netdissect/segmenter.py
@@ -0,0 +1,581 @@
+# Usage as a simple differentiable segmenter base class
+
+import os, torch, numpy, json, glob
+import skimage.morphology
+from collections import OrderedDict
+from netdissect import upsegmodel
+from netdissect import segmodel as segmodel_module
+from netdissect.easydict import EasyDict
+from urllib.request import urlretrieve
+
+class BaseSegmenter:
+ def get_label_and_category_names(self):
+ '''
+ Returns two lists: first, a list of tuples [(label, category), ...]
+ where the label and category are human-readable strings indicating
+ the meaning of a segmentation class. The 0th segmentation class
+ should be reserved for a label ('-') that means "no prediction."
+ The second list should just be a list of [category,...] listing
+ all categories in a canonical order.
+ '''
+        raise NotImplementedError()
+
+ def segment_batch(self, tensor_images, downsample=1):
+ '''
+ Returns a multilabel segmentation for the given batch of (RGB [-1...1])
+ images. Each pixel of the result is a torch.long indicating a
+ predicted class number. Multiple classes can be predicted for
+ the same pixel: output shape is (n, multipred, y, x), where
+ multipred is 3, 5, or 6, for how many different predicted labels can
+ be given for each pixel (depending on whether subdivision is being
+ used). If downsample is specified, then the output y and x dimensions
+ are downsampled from the original image.
+ '''
+        raise NotImplementedError()
+
+ def predict_single_class(self, tensor_images, classnum, downsample=1):
+ '''
+ Given a batch of images (RGB, normalized to [-1...1]) and
+ a specific segmentation class number, returns a tuple with
+ (1) a differentiable ([0..1]) prediction score for the class
+ at every pixel of the input image.
+ (2) a binary mask showing where in the input image the
+ specified class is the best-predicted label for the pixel.
+ Does not work on subdivided labels.
+ '''
+        raise NotImplementedError()
+
+class UnifiedParsingSegmenter(BaseSegmenter):
+ '''
+ This is a wrapper for a more complicated multi-class segmenter,
+ as described in https://arxiv.org/pdf/1807.10221.pdf, and as
+ released in https://github.com/CSAILVision/unifiedparsing.
+ For our purposes and to simplify processing, we do not use
+ whole-scene predictions, and we only consume part segmentations
+ for the three largest object classes (sky, building, person).
+ '''
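+    # Illustrative usage sketch (assumes a CUDA device and that the 'upp'
+    # model files are present, e.g. via
+    # ensure_upp_segmenter_downloaded('dataset/segmodel')):
+    #   segmenter = UnifiedParsingSegmenter(segsizes=[256])
+    #   seg = segmenter.segment_batch(images)  # images: RGB in [-1, 1]
+    #   labels, categories = segmenter.get_label_and_category_names()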
+
+ def __init__(self, segsizes=None, segdiv=None):
+ # Create a segmentation model
+ if segsizes is None:
+ segsizes = [256]
+        if segdiv is None:
+ segdiv = 'undivided'
+ segvocab = 'upp'
+ segarch = ('resnet50', 'upernet')
+ epoch = 40
+ segmodel = load_unified_parsing_segmentation_model(
+ segarch, segvocab, epoch)
+ segmodel.cuda()
+ self.segmodel = segmodel
+ self.segsizes = segsizes
+ self.segdiv = segdiv
+ mult = 1
+ if self.segdiv == 'quad':
+ mult = 5
+ self.divmult = mult
+ # Assign class numbers for parts.
+ first_partnumber = (
+ (len(segmodel.labeldata['object']) - 1) * mult + 1 +
+ (len(segmodel.labeldata['material']) - 1))
+ # We only use parts for these three types of objects, for efficiency.
+ partobjects = ['sky', 'building', 'person']
+ partnumbers = {}
+ partnames = []
+ objectnumbers = {k: v
+ for v, k in enumerate(segmodel.labeldata['object'])}
+ part_index_translation = []
+ # We merge some classes. For example "door" is both an object
+ # and a part of a building. To avoid confusion, we just count
+ # such classes as objects, and add part scores to the same index.
+ for owner in partobjects:
+ part_list = segmodel.labeldata['object_part'][owner]
+ numeric_part_list = []
+ for part in part_list:
+ if part in objectnumbers:
+ numeric_part_list.append(objectnumbers[part])
+ elif part in partnumbers:
+ numeric_part_list.append(partnumbers[part])
+ else:
+ partnumbers[part] = len(partnames) + first_partnumber
+ partnames.append(part)
+ numeric_part_list.append(partnumbers[part])
+ part_index_translation.append(torch.tensor(numeric_part_list))
+ self.objects_with_parts = [objectnumbers[obj] for obj in partobjects]
+ self.part_index = part_index_translation
+ self.part_names = partnames
+        # Total class count: objects (with any quad divisions), materials, and parts.
+ self.num_classes = 1 + (
+ len(segmodel.labeldata['object']) - 1) * mult + (
+ len(segmodel.labeldata['material']) - 1) + len(partnames)
+ self.num_object_classes = len(self.segmodel.labeldata['object']) - 1
+
+ def get_label_and_category_names(self, dataset=None):
+ '''
+ Lists label and category names.
+ '''
+ # Labels are ordered as follows:
+ # 0, [object labels] [divided object labels] [materials] [parts]
+ # The zero label is reserved to mean 'no prediction'.
+ if self.segdiv == 'quad':
+ suffixes = ['t', 'l', 'b', 'r']
+ else:
+ suffixes = []
+ divided_labels = []
+ for suffix in suffixes:
+ divided_labels.extend([('%s-%s' % (label, suffix), 'part')
+ for label in self.segmodel.labeldata['object'][1:]])
+ # Create the whole list of labels
+ labelcats = (
+ [(label, 'object')
+ for label in self.segmodel.labeldata['object']] +
+ divided_labels +
+ [(label, 'material')
+ for label in self.segmodel.labeldata['material'][1:]] +
+ [(label, 'part') for label in self.part_names])
+ return labelcats, ['object', 'part', 'material']
+
+ def raw_seg_prediction(self, tensor_images, downsample=1):
+ '''
+        Generates a segmentation prediction by multiresolution voting:
+        the segmentation model is run at each resolution in self.segsizes
+        (as in the example benchmark code, sizes should be multiples of 32)
+        and the predictions are summed.
+ '''
+ y, x = tensor_images.shape[2:]
+ b = len(tensor_images)
+ tensor_images = (tensor_images + 1) / 2 * 255
+        tensor_images = torch.flip(tensor_images, (1,))  # RGB -> BGR, matching the model's expected channel order
+ tensor_images -= torch.tensor([102.9801, 115.9465, 122.7717]).to(
+ dtype=tensor_images.dtype, device=tensor_images.device
+ )[None,:,None,None]
+ seg_shape = (y // downsample, x // downsample)
+ # We want these to be multiples of 32 for the model.
+ sizes = [(s, s) for s in self.segsizes]
+ pred = {category: torch.zeros(
+ len(tensor_images), len(self.segmodel.labeldata[category]),
+ seg_shape[0], seg_shape[1]).cuda()
+ for category in ['object', 'material']}
+ part_pred = {partobj_index: torch.zeros(
+ len(tensor_images), len(partindex),
+ seg_shape[0], seg_shape[1]).cuda()
+ for partobj_index, partindex in enumerate(self.part_index)}
+ for size in sizes:
+ if size == tensor_images.shape[2:]:
+ resized = tensor_images
+ else:
+ resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images)
+ r_pred = self.segmodel(
+ dict(img=resized), seg_size=seg_shape)
+ for k in pred:
+ pred[k] += r_pred[k]
+ for k in part_pred:
+ part_pred[k] += r_pred['part'][k]
+ return pred, part_pred
+
+ def segment_batch(self, tensor_images, downsample=1):
+ '''
+ Returns a multilabel segmentation for the given batch of (RGB [-1...1])
+ images. Each pixel of the result is a torch.long indicating a
+ predicted class number. Multiple classes can be predicted for
+ the same pixel: output shape is (n, multipred, y, x), where
+ multipred is 3, 5, or 6, for how many different predicted labels can
+ be given for each pixel (depending on whether subdivision is being
+ used). If downsample is specified, then the output y and x dimensions
+ are downsampled from the original image.
+ '''
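+        # Channel layout of the result: channel 0 holds the object label,
+        # channel 1 the material label, and channel 2 the part label; when
+        # segdiv == 'quad', two extra channels hold quad-divided object labels.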
+ pred, part_pred = self.raw_seg_prediction(tensor_images,
+ downsample=downsample)
+ piece_channels = 2 if self.segdiv == 'quad' else 0
+ y, x = tensor_images.shape[2:]
+ seg_shape = (y // downsample, x // downsample)
+ segs = torch.zeros(len(tensor_images), 3 + piece_channels,
+ seg_shape[0], seg_shape[1],
+ dtype=torch.long, device=tensor_images.device)
+ _, segs[:,0] = torch.max(pred['object'], dim=1)
+ # Get materials and translate to shared numbering scheme
+ _, segs[:,1] = torch.max(pred['material'], dim=1)
+ maskout = (segs[:,1] == 0)
+ segs[:,1] += (len(self.segmodel.labeldata['object']) - 1) * self.divmult
+ segs[:,1][maskout] = 0
+ # Now deal with subparts of sky, buildings, people
+ for i, object_index in enumerate(self.objects_with_parts):
+ trans = self.part_index[i].to(segs.device)
+ # Get the argmax, and then translate to shared numbering scheme
+ seg = trans[torch.max(part_pred[i], dim=1)[1]]
+ # Only trust the parts where the prediction also predicts the
+ # owning object.
+ mask = (segs[:,0] == object_index)
+ segs[:,2][mask] = seg[mask]
+
+ if self.segdiv == 'quad':
+ segs = self.expand_segment_quad(segs, self.segdiv)
+ return segs
+
+ def predict_single_class(self, tensor_images, classnum, downsample=1):
+ '''
+ Given a batch of images (RGB, normalized to [-1...1]) and
+ a specific segmentation class number, returns a tuple with
+ (1) a differentiable ([0..1]) prediction score for the class
+ at every pixel of the input image.
+ (2) a binary mask showing where in the input image the
+ specified class is the best-predicted label for the pixel.
+ Does not work on subdivided labels.
+ '''
+ result = 0
+ pred, part_pred = self.raw_seg_prediction(tensor_images,
+ downsample=downsample)
+ material_offset = (len(self.segmodel.labeldata['object']) - 1
+ ) * self.divmult
+ if material_offset < classnum < material_offset + len(
+ self.segmodel.labeldata['material']):
+ return (
+ pred['material'][:, classnum - material_offset],
+ pred['material'].max(dim=1)[1] == classnum - material_offset)
+ mask = None
+ if classnum < len(self.segmodel.labeldata['object']):
+ result = pred['object'][:, classnum]
+ mask = (pred['object'].max(dim=1)[1] == classnum)
+ # Some objects, like 'door', are also a part of other objects,
+ # so add the part prediction also.
+ for i, object_index in enumerate(self.objects_with_parts):
+ local_index = (self.part_index[i] == classnum).nonzero()
+ if len(local_index) == 0:
+ continue
+ local_index = local_index.item()
+ # Ignore part predictions outside the mask. (We could pay
+            # attention to and penalize such predictions.)
+ mask2 = (pred['object'].max(dim=1)[1] == object_index) * (
+ part_pred[i].max(dim=1)[1] == local_index)
+ if mask is None:
+ mask = mask2
+ else:
+ mask = torch.max(mask, mask2)
+ result = result + (part_pred[i][:, local_index])
+        assert torch.is_tensor(result), 'unrecognized class %d' % classnum
+ return result, mask
+
+ def expand_segment_quad(self, segs, segdiv='quad'):
+ shape = segs.shape
+ segs[:,3:] = segs[:,0:1] # start by copying the object channel
+ num_seg_labels = self.num_object_classes
+ # For every connected component present (using generator)
+ for i, mask in component_masks(segs[:,0:1]):
+ # Figure the bounding box of the label
+ top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0]
+ left, right = mask.any(dim=0).nonzero()[[0, -1], 0]
+ # Chop the bounding box into four parts
+ vmid = (top + bottom + 1) // 2
+ hmid = (left + right + 1) // 2
+            # Construct top, left, bottom, right masks (matching the
+            # 't', 'l', 'b', 'r' suffix order used for divided labels)
+            quad_mask = mask[None,:,:].repeat(4, 1, 1)
+            quad_mask[0, vmid:, :] = 0 # top half
+            quad_mask[1, :, hmid:] = 0 # left half
+            quad_mask[2, :vmid, :] = 0 # bottom half
+            quad_mask[3, :, :hmid] = 0 # right half
+ quad_mask = quad_mask.long()
+ # Modify extra segmentation labels by offsetting
+ segs[i,3,:,:] += quad_mask[0] * num_seg_labels
+ segs[i,4,:,:] += quad_mask[1] * (2 * num_seg_labels)
+ segs[i,3,:,:] += quad_mask[2] * (3 * num_seg_labels)
+ segs[i,4,:,:] += quad_mask[3] * (4 * num_seg_labels)
+ # remove any components that were too small to subdivide
+ mask = segs[:,3:] <= self.num_object_classes
+ segs[:,3:][mask] = 0
+ return segs
+
+class SemanticSegmenter(BaseSegmenter):
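+    '''
+    Wraps a generic semantic segmentation model described by a labels.json
+    metadata file (labels, categories, image normalization). Each category
+    gets its own softmax over its channels, and the result can optionally be
+    quad-subdivided like UnifiedParsingSegmenter.
+    '''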
+ def __init__(self, modeldir=None, segarch=None, segvocab=None,
+ segsizes=None, segdiv=None, epoch=None):
+ # Create a segmentation model
+        if modeldir is None:
+            modeldir = 'dataset/segmodel'
+        if segvocab is None:
+            segvocab = 'baseline'
+        if segarch is None:
+            segarch = ('resnet50_dilated8', 'ppm_bilinear_deepsup')
+        elif isinstance(segarch, str):
+            segarch = segarch.split(',')
+        if segdiv is None:
+            segdiv = 'undivided'
+ segmodel = load_segmentation_model(modeldir, segarch, segvocab, epoch)
+ if segsizes is None:
+ segsizes = getattr(segmodel.meta, 'segsizes', [256])
+ self.segsizes = segsizes
+        # Verify that the segmentation model has every output channel labeled.
+        assert len(segmodel.meta.labels) == [c for c in segmodel.modules()
+            if isinstance(c, torch.nn.Conv2d)][-1].out_channels
+ segmodel.cuda()
+ self.segmodel = segmodel
+ self.segdiv = segdiv
+ # Image normalization
+ self.bgr = (segmodel.meta.imageformat.byteorder == 'BGR')
+ self.imagemean = torch.tensor(segmodel.meta.imageformat.mean)
+ self.imagestd = torch.tensor(segmodel.meta.imageformat.stdev)
+ # Map from labels to external indexes, and labels to channel sets.
+ self.labelmap = {'-': 0}
+ self.channelmap = {'-': []}
+ self.labels = [('-', '-')]
+ num_labels = 1
+ self.num_underlying_classes = len(segmodel.meta.labels)
+ # labelmap maps names to external indexes.
+ for i, label in enumerate(segmodel.meta.labels):
+ if label.name not in self.channelmap:
+ self.channelmap[label.name] = []
+ self.channelmap[label.name].append(i)
+ if getattr(label, 'internal', None) or label.name in self.labelmap:
+ continue
+ self.labelmap[label.name] = num_labels
+ num_labels += 1
+ self.labels.append((label.name, label.category))
+ # Each category gets its own independent softmax.
+ self.category_indexes = { category.name:
+ [i for i, label in enumerate(segmodel.meta.labels)
+ if label.category == category.name]
+ for category in segmodel.meta.categories }
+ # catindexmap maps names to category internal indexes
+ self.catindexmap = {}
+ for catname, indexlist in self.category_indexes.items():
+ for index, i in enumerate(indexlist):
+ self.catindexmap[segmodel.meta.labels[i].name] = (
+ (catname, index))
+ # After the softmax, each category is mapped to external indexes.
+ self.category_map = { catname:
+ torch.tensor([
+ self.labelmap.get(segmodel.meta.labels[ind].name, 0)
+ for ind in catindex])
+ for catname, catindex in self.category_indexes.items()}
+ self.category_rules = segmodel.meta.categories
+ # Finally, naive subdivision can be applied.
+ mult = 1
+ if self.segdiv == 'quad':
+ mult = 5
+ suffixes = ['t', 'l', 'b', 'r']
+ divided_labels = []
+ for suffix in suffixes:
+ divided_labels.extend([('%s-%s' % (label, suffix), cat)
+ for label, cat in self.labels[1:]])
+ self.channelmap.update({
+ '%s-%s' % (label, suffix): self.channelmap[label]
+ for label, cat in self.labels[1:] })
+ self.labels.extend(divided_labels)
+ # For examining a single class
+ self.channellist = [self.channelmap[name] for name, _ in self.labels]
+
+ def get_label_and_category_names(self, dataset=None):
+ return self.labels, self.segmodel.categories
+
+ def segment_batch(self, tensor_images, downsample=1):
+ return self.raw_segment_batch(tensor_images, downsample)[0]
+
+ def raw_segment_batch(self, tensor_images, downsample=1):
+ pred = self.raw_seg_prediction(tensor_images, downsample)
+ catsegs = {}
+ for catkey, catindex in self.category_indexes.items():
+ _, segs = torch.max(pred[:, catindex], dim=1)
+ catsegs[catkey] = segs
+ masks = {}
+ segs = torch.zeros(len(tensor_images), len(self.category_rules),
+            pred.shape[2], pred.shape[3], device=pred.device,
+ dtype=torch.long)
+ for i, cat in enumerate(self.category_rules):
+ catmap = self.category_map[cat.name].to(pred.device)
+ translated = catmap[catsegs[cat.name]]
+ if getattr(cat, 'mask', None) is not None:
+ if cat.mask not in masks:
+ maskcat, maskind = self.catindexmap[cat.mask]
+ masks[cat.mask] = (catsegs[maskcat] == maskind)
+ translated *= masks[cat.mask].long()
+ segs[:,i] = translated
+ if self.segdiv == 'quad':
+ segs = self.expand_segment_quad(segs,
+ self.num_underlying_classes, self.segdiv)
+ return segs, pred
+
+ def raw_seg_prediction(self, tensor_images, downsample=1):
+ '''
+        Generates a segmentation prediction by multiresolution voting:
+        the model is run at each resolution in self.segsizes (as in the
+        example benchmark code, sizes should be multiples of 32) and the
+        per-category softmax predictions are summed.
+ '''
+ y, x = tensor_images.shape[2:]
+ b = len(tensor_images)
+ # Flip the RGB order if specified.
+ if self.bgr:
+ tensor_images = torch.flip(tensor_images, (1,))
+ # Transform from our [-1..1] range to torch standard [0..1] range
+ # and then apply normalization.
+ tensor_images = ((tensor_images + 1) / 2
+ ).sub_(self.imagemean[None,:,None,None].to(tensor_images.device)
+ ).div_(self.imagestd[None,:,None,None].to(tensor_images.device))
+ # Output shape can be downsampled.
+ seg_shape = (y // downsample, x // downsample)
+ # We want these to be multiples of 32 for the model.
+ sizes = [(s, s) for s in self.segsizes]
+ pred = torch.zeros(
+ len(tensor_images), (self.num_underlying_classes),
+ seg_shape[0], seg_shape[1]).cuda()
+ for size in sizes:
+ if size == tensor_images.shape[2:]:
+ resized = tensor_images
+ else:
+ resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images)
+ raw_pred = self.segmodel(
+ dict(img_data=resized), segSize=seg_shape)
+ softmax_pred = torch.empty_like(raw_pred)
+ for catindex in self.category_indexes.values():
+ softmax_pred[:, catindex] = torch.nn.functional.softmax(
+ raw_pred[:, catindex], dim=1)
+ pred += softmax_pred
+ return pred
+
+ def expand_segment_quad(self, segs, num_seg_labels, segdiv='quad'):
+ shape = segs.shape
+ output = segs.repeat(1, 3, 1, 1)
+ # For every connected component present (using generator)
+ for i, mask in component_masks(segs):
+ # Figure the bounding box of the label
+ top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0]
+ left, right = mask.any(dim=0).nonzero()[[0, -1], 0]
+ # Chop the bounding box into four parts
+ vmid = (top + bottom + 1) // 2
+ hmid = (left + right + 1) // 2
+            # Construct top, left, bottom, right masks (matching the
+            # 't', 'l', 'b', 'r' suffix order used for divided labels)
+            quad_mask = mask[None,:,:].repeat(4, 1, 1)
+            quad_mask[0, vmid:, :] = 0 # top half
+            quad_mask[1, :, hmid:] = 0 # left half
+            quad_mask[2, :vmid, :] = 0 # bottom half
+            quad_mask[3, :, :hmid] = 0 # right half
+ quad_mask = quad_mask.long()
+ # Modify extra segmentation labels by offsetting
+ output[i,1,:,:] += quad_mask[0] * num_seg_labels
+ output[i,2,:,:] += quad_mask[1] * (2 * num_seg_labels)
+ output[i,1,:,:] += quad_mask[2] * (3 * num_seg_labels)
+ output[i,2,:,:] += quad_mask[3] * (4 * num_seg_labels)
+ return output
+
+ def predict_single_class(self, tensor_images, classnum, downsample=1):
+ '''
+ Given a batch of images (RGB, normalized to [-1...1]) and
+ a specific segmentation class number, returns a tuple with
+ (1) a differentiable ([0..1]) prediction score for the class
+ at every pixel of the input image.
+ (2) a binary mask showing where in the input image the
+ specified class is the best-predicted label for the pixel.
+ Does not work on subdivided labels.
+ '''
+ seg, pred = self.raw_segment_batch(tensor_images,
+ downsample=downsample)
+ result = pred[:,self.channellist[classnum]].sum(dim=1)
+ mask = (seg == classnum).max(1)[0]
+ return result, mask
+
+def component_masks(segmentation_batch):
+ '''
+ Splits connected components into regions (slower, requires cpu).
+ '''
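+    # Yields (i, mask) pairs: the batch index i and a boolean mask selecting
+    # one connected component of the first channel of segmentation_batch[i].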
+ npbatch = segmentation_batch.cpu().numpy()
+ for i in range(segmentation_batch.shape[0]):
+ labeled, num = skimage.morphology.label(npbatch[i][0], return_num=True)
+ labeled = torch.from_numpy(labeled).to(segmentation_batch.device)
+        for label in range(1, num + 1):
+ yield i, (labeled == label)
+
+def load_unified_parsing_segmentation_model(segmodel_arch, segvocab, epoch):
+ segmodel_dir = 'dataset/segmodel/%s-%s-%s' % ((segvocab,) + segmodel_arch)
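+    # For the defaults used by UnifiedParsingSegmenter this resolves to
+    # 'dataset/segmodel/upp-resnet50-upernet'.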
+ # Load json of class names and part/object structure
+ with open(os.path.join(segmodel_dir, 'labels.json')) as f:
+ labeldata = json.load(f)
+ nr_classes={k: len(labeldata[k])
+ for k in ['object', 'scene', 'material']}
+ nr_classes['part'] = sum(len(p) for p in labeldata['object_part'].values())
+ # Create a segmentation model
+ segbuilder = upsegmodel.ModelBuilder()
+ # example segmodel_arch = ('resnet101', 'upernet')
+ seg_encoder = segbuilder.build_encoder(
+ arch=segmodel_arch[0],
+ fc_dim=2048,
+ weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch))
+ seg_decoder = segbuilder.build_decoder(
+ arch=segmodel_arch[1],
+ fc_dim=2048, use_softmax=True,
+ nr_classes=nr_classes,
+ weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch))
+ segmodel = upsegmodel.SegmentationModule(
+ seg_encoder, seg_decoder, labeldata)
+ segmodel.categories = ['object', 'part', 'material']
+ segmodel.eval()
+ return segmodel
+
+def load_segmentation_model(modeldir, segmodel_arch, segvocab, epoch=None):
+    # Load json of class names and categories
+    segmodel_dir = os.path.join(
+        modeldir, '%s-%s-%s' % ((segvocab,) + segmodel_arch))
+ with open(os.path.join(segmodel_dir, 'labels.json')) as f:
+ labeldata = EasyDict(json.load(f))
+ # Automatically pick the last epoch available.
+ if epoch is None:
+ choices = [os.path.basename(n)[14:-4] for n in
+ glob.glob(os.path.join(segmodel_dir, 'encoder_epoch_*.pth'))]
+ epoch = max([int(c) for c in choices if c.isdigit()])
+ # Create a segmentation model
+ segbuilder = segmodel_module.ModelBuilder()
+ # example segmodel_arch = ('resnet101', 'upernet')
+ seg_encoder = segbuilder.build_encoder(
+ arch=segmodel_arch[0],
+ fc_dim=2048,
+ weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch))
+ seg_decoder = segbuilder.build_decoder(
+ arch=segmodel_arch[1],
+ fc_dim=2048, inference=True, num_class=len(labeldata.labels),
+ weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch))
+ segmodel = segmodel_module.SegmentationModule(seg_encoder, seg_decoder,
+ torch.nn.NLLLoss(ignore_index=-1))
+ segmodel.categories = [cat.name for cat in labeldata.categories]
+ segmodel.labels = [label.name for label in labeldata.labels]
+ categories = OrderedDict()
+ label_category = numpy.zeros(len(segmodel.labels), dtype=int)
+ for i, label in enumerate(labeldata.labels):
+ label_category[i] = segmodel.categories.index(label.category)
+ segmodel.meta = labeldata
+ segmodel.eval()
+ return segmodel
+
+def ensure_upp_segmenter_downloaded(directory):
+ baseurl = 'http://netdissect.csail.mit.edu/data/segmodel'
+ dirname = 'upp-resnet50-upernet'
+ files = ['decoder_epoch_40.pth', 'encoder_epoch_40.pth', 'labels.json']
+ download_dir = os.path.join(directory, dirname)
+ os.makedirs(download_dir, exist_ok=True)
+ for fn in files:
+ if os.path.isfile(os.path.join(download_dir, fn)):
+ continue # Skip files already downloaded
+ url = '%s/%s/%s' % (baseurl, dirname, fn)
+ print('Downloading %s' % url)
+ urlretrieve(url, os.path.join(download_dir, fn))
+ assert os.path.isfile(os.path.join(directory, dirname, 'labels.json'))
+
+def test_main():
+ '''
+ Test the unified segmenter.
+ '''
+ from PIL import Image
+ testim = Image.open('script/testdata/test_church_242.jpg')
+ tensor_im = (torch.from_numpy(numpy.asarray(testim)).permute(2, 0, 1)
+ .float() / 255 * 2 - 1)[None, :, :, :].cuda()
+ segmenter = UnifiedParsingSegmenter()
+ seg = segmenter.segment_batch(tensor_im)
+ bc = torch.bincount(seg.view(-1))
+ labels, cats = segmenter.get_label_and_category_names()
+ for label in bc.nonzero()[:,0]:
+ if label.item():
+ # What is the prediction for this class?
+ pred, mask = segmenter.predict_single_class(tensor_im, label.item())
+ assert mask.sum().item() == bc[label].item()
+ assert len(((seg == label).max(1)[0] - mask).nonzero()) == 0
+ inside_pred = pred[mask].mean().item()
+ outside_pred = pred[~mask].mean().item()
+ print('%s (%s, #%d): %d pixels, pred %.2g inside %.2g outside' %
+ (labels[label.item()] + (label.item(), bc[label].item(),
+ inside_pred, outside_pred)))
+
+if __name__ == '__main__':
+ test_main()
diff --git a/netdissect/segmodel/__init__.py b/netdissect/segmodel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76b40a0a36bc2976f185dbdc344c5a7c09b65920
--- /dev/null
+++ b/netdissect/segmodel/__init__.py
@@ -0,0 +1 @@
+from .models import ModelBuilder, SegmentationModule
diff --git a/netdissect/segmodel/colors150.npy b/netdissect/segmodel/colors150.npy
new file mode 100644
index 0000000000000000000000000000000000000000..2384b386dabded09c47329a360987a0d7f67d697
--- /dev/null
+++ b/netdissect/segmodel/colors150.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9823be654b7a9c135e355952b580f567d5127e98d99ec792ebc349d5fc8c137
+size 578
diff --git a/netdissect/segmodel/models.py b/netdissect/segmodel/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceb6f2ce21720722d5d8c9ee4f7e015ad06a9647
--- /dev/null
+++ b/netdissect/segmodel/models.py
@@ -0,0 +1,558 @@
+import torch
+import torch.nn as nn
+import torchvision
+from . import resnet, resnext
+try:
+ from lib.nn import SynchronizedBatchNorm2d
+except ImportError:
+ from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
+
+
+class SegmentationModuleBase(nn.Module):
+ def __init__(self):
+ super(SegmentationModuleBase, self).__init__()
+
+ def pixel_acc(self, pred, label):
+ _, preds = torch.max(pred, dim=1)
+ valid = (label >= 0).long()
+ acc_sum = torch.sum(valid * (preds == label).long())
+ pixel_sum = torch.sum(valid)
+ acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
+ return acc
+
+
+class SegmentationModule(SegmentationModuleBase):
+ def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
+ super(SegmentationModule, self).__init__()
+ self.encoder = net_enc
+ self.decoder = net_dec
+ self.crit = crit
+ self.deep_sup_scale = deep_sup_scale
+
+ def forward(self, feed_dict, *, segSize=None):
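+        # Training mode (segSize is None): feed_dict supplies 'img_data' and
+        # 'seg_label', and the module returns (loss, pixel accuracy).
+        # Inference mode (segSize given): returns per-pixel class scores
+        # resized to segSize, e.g. model(dict(img_data=images), segSize=(h, w)).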
+ if segSize is None: # training
+ if self.deep_sup_scale is not None: # use deep supervision technique
+ (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
+ else:
+ pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
+
+ loss = self.crit(pred, feed_dict['seg_label'])
+ if self.deep_sup_scale is not None:
+ loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
+ loss = loss + loss_deepsup * self.deep_sup_scale
+
+ acc = self.pixel_acc(pred, feed_dict['seg_label'])
+ return loss, acc
+ else: # inference
+ pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
+ return pred
+
+
+def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=has_bias)
+
+
+def conv3x3_bn_relu(in_planes, out_planes, stride=1):
+ return nn.Sequential(
+ conv3x3(in_planes, out_planes, stride),
+ SynchronizedBatchNorm2d(out_planes),
+ nn.ReLU(inplace=True),
+ )
+
+
+class ModelBuilder():
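+    # Builds encoder/decoder pairs. A minimal usage sketch (the weight file
+    # names are placeholders):
+    #   builder = ModelBuilder()
+    #   enc = builder.build_encoder('resnet50_dilated8', fc_dim=2048,
+    #                               weights='encoder_epoch_20.pth')
+    #   dec = builder.build_decoder('ppm_bilinear_deepsup', fc_dim=2048,
+    #                               num_class=150, inference=True,
+    #                               weights='decoder_epoch_20.pth')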
+ # custom weights initialization
+ def weights_init(self, m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.kaiming_normal_(m.weight.data)
+ elif classname.find('BatchNorm') != -1:
+ m.weight.data.fill_(1.)
+ m.bias.data.fill_(1e-4)
+ #elif classname.find('Linear') != -1:
+ # m.weight.data.normal_(0.0, 0.0001)
+
+ def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
+        pretrained = len(weights) == 0
+ if arch == 'resnet34':
+ raise NotImplementedError
+ orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ elif arch == 'resnet34_dilated8':
+ raise NotImplementedError
+ orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet,
+ dilate_scale=8)
+ elif arch == 'resnet34_dilated16':
+ raise NotImplementedError
+ orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet,
+ dilate_scale=16)
+ elif arch == 'resnet50':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ elif arch == 'resnet50_dilated8':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet,
+ dilate_scale=8)
+ elif arch == 'resnet50_dilated16':
+ orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet,
+ dilate_scale=16)
+ elif arch == 'resnet101':
+ orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnet)
+ elif arch == 'resnet101_dilated8':
+ orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet,
+ dilate_scale=8)
+ elif arch == 'resnet101_dilated16':
+ orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
+ net_encoder = ResnetDilated(orig_resnet,
+ dilate_scale=16)
+ elif arch == 'resnext101':
+ orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained)
+ net_encoder = Resnet(orig_resnext) # we can still use class Resnet
+ else:
+ raise Exception('Architecture undefined!')
+
+ # net_encoder.apply(self.weights_init)
+ if len(weights) > 0:
+ # print('Loading weights for net_encoder')
+ net_encoder.load_state_dict(
+ torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
+ return net_encoder
+
+ def build_decoder(self, arch='ppm_bilinear_deepsup',
+ fc_dim=512, num_class=150,
+ weights='', inference=False, use_softmax=False):
+ if arch == 'c1_bilinear_deepsup':
+ net_decoder = C1BilinearDeepSup(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax)
+ elif arch == 'c1_bilinear':
+ net_decoder = C1Bilinear(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax)
+ elif arch == 'ppm_bilinear':
+ net_decoder = PPMBilinear(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax)
+ elif arch == 'ppm_bilinear_deepsup':
+ net_decoder = PPMBilinearDeepsup(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax)
+ elif arch == 'upernet_lite':
+ net_decoder = UPerNet(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax,
+ fpn_dim=256)
+ elif arch == 'upernet':
+ net_decoder = UPerNet(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax,
+ fpn_dim=512)
+ elif arch == 'upernet_tmp':
+ net_decoder = UPerNetTmp(
+ num_class=num_class,
+ fc_dim=fc_dim,
+ inference=inference,
+ use_softmax=use_softmax,
+ fpn_dim=512)
+ else:
+ raise Exception('Architecture undefined!')
+
+ net_decoder.apply(self.weights_init)
+ if len(weights) > 0:
+ # print('Loading weights for net_decoder')
+ net_decoder.load_state_dict(
+ torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
+ return net_decoder
+
+
+class Resnet(nn.Module):
+ def __init__(self, orig_resnet):
+ super(Resnet, self).__init__()
+
+ # take pretrained resnet, except AvgPool and FC
+ self.conv1 = orig_resnet.conv1
+ self.bn1 = orig_resnet.bn1
+ self.relu1 = orig_resnet.relu1
+ self.conv2 = orig_resnet.conv2
+ self.bn2 = orig_resnet.bn2
+ self.relu2 = orig_resnet.relu2
+ self.conv3 = orig_resnet.conv3
+ self.bn3 = orig_resnet.bn3
+ self.relu3 = orig_resnet.relu3
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def forward(self, x, return_feature_maps=False):
+ conv_out = []
+
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x); conv_out.append(x);
+ x = self.layer2(x); conv_out.append(x);
+ x = self.layer3(x); conv_out.append(x);
+ x = self.layer4(x); conv_out.append(x);
+
+ if return_feature_maps:
+ return conv_out
+ return [x]
+
+
+class ResnetDilated(nn.Module):
+ def __init__(self, orig_resnet, dilate_scale=8):
+ super(ResnetDilated, self).__init__()
+ from functools import partial
+
+ if dilate_scale == 8:
+ orig_resnet.layer3.apply(
+ partial(self._nostride_dilate, dilate=2))
+ orig_resnet.layer4.apply(
+ partial(self._nostride_dilate, dilate=4))
+ elif dilate_scale == 16:
+ orig_resnet.layer4.apply(
+ partial(self._nostride_dilate, dilate=2))
+
+ # take pretrained resnet, except AvgPool and FC
+ self.conv1 = orig_resnet.conv1
+ self.bn1 = orig_resnet.bn1
+ self.relu1 = orig_resnet.relu1
+ self.conv2 = orig_resnet.conv2
+ self.bn2 = orig_resnet.bn2
+ self.relu2 = orig_resnet.relu2
+ self.conv3 = orig_resnet.conv3
+ self.bn3 = orig_resnet.bn3
+ self.relu3 = orig_resnet.relu3
+ self.maxpool = orig_resnet.maxpool
+ self.layer1 = orig_resnet.layer1
+ self.layer2 = orig_resnet.layer2
+ self.layer3 = orig_resnet.layer3
+ self.layer4 = orig_resnet.layer4
+
+ def _nostride_dilate(self, m, dilate):
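+        # Convert stride-2 convolutions to stride 1 and increase dilation and
+        # padding instead, so the backbone keeps its receptive field while
+        # producing higher-resolution feature maps.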
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ # the convolution with stride
+ if m.stride == (2, 2):
+ m.stride = (1, 1)
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate//2, dilate//2)
+ m.padding = (dilate//2, dilate//2)
+            # other convolutions
+ else:
+ if m.kernel_size == (3, 3):
+ m.dilation = (dilate, dilate)
+ m.padding = (dilate, dilate)
+
+ def forward(self, x, return_feature_maps=False):
+ conv_out = []
+
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x); conv_out.append(x);
+ x = self.layer2(x); conv_out.append(x);
+ x = self.layer3(x); conv_out.append(x);
+ x = self.layer4(x); conv_out.append(x);
+
+ if return_feature_maps:
+ return conv_out
+ return [x]
+
+
+# last conv, bilinear upsample
+class C1BilinearDeepSup(nn.Module):
+ def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False):
+ super(C1BilinearDeepSup, self).__init__()
+ self.use_softmax = use_softmax
+ self.inference = inference
+
+ self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
+ self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
+
+ # last conv
+ self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+ self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ x = self.cbr(conv5)
+ x = self.conv_last(x)
+
+ if self.inference or self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ if self.use_softmax:
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ # deep sup
+ conv4 = conv_out[-2]
+ _ = self.cbr_deepsup(conv4)
+ _ = self.conv_last_deepsup(_)
+
+ x = nn.functional.log_softmax(x, dim=1)
+ _ = nn.functional.log_softmax(_, dim=1)
+
+ return (x, _)
+
+
+# last conv, bilinear upsample
+class C1Bilinear(nn.Module):
+ def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False):
+ super(C1Bilinear, self).__init__()
+ self.use_softmax = use_softmax
+ self.inference = inference
+
+ self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
+
+ # last conv
+ self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+ x = self.cbr(conv5)
+ x = self.conv_last(x)
+
+ if self.inference or self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ if self.use_softmax:
+ x = nn.functional.softmax(x, dim=1)
+ else:
+ x = nn.functional.log_softmax(x, dim=1)
+
+ return x
+
+
+# pyramid pooling, bilinear upsample
+class PPMBilinear(nn.Module):
+ def __init__(self, num_class=150, fc_dim=4096,
+ inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)):
+ super(PPMBilinear, self).__init__()
+ self.use_softmax = use_softmax
+ self.inference = inference
+
+ self.ppm = []
+ for scale in pool_scales:
+ self.ppm.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(scale),
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ SynchronizedBatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm = nn.ModuleList(self.ppm)
+
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
+ kernel_size=3, padding=1, bias=False),
+ SynchronizedBatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(512, num_class, kernel_size=1)
+ )
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale in self.ppm:
+ ppm_out.append(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False))
+ ppm_out = torch.cat(ppm_out, 1)
+
+ x = self.conv_last(ppm_out)
+
+ if self.inference or self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ if self.use_softmax:
+ x = nn.functional.softmax(x, dim=1)
+ else:
+ x = nn.functional.log_softmax(x, dim=1)
+ return x
+
+
+# pyramid pooling, bilinear upsample
+class PPMBilinearDeepsup(nn.Module):
+ def __init__(self, num_class=150, fc_dim=4096,
+ inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)):
+ super(PPMBilinearDeepsup, self).__init__()
+ self.use_softmax = use_softmax
+ self.inference = inference
+
+ self.ppm = []
+ for scale in pool_scales:
+ self.ppm.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(scale),
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ SynchronizedBatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm = nn.ModuleList(self.ppm)
+ self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
+
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
+ kernel_size=3, padding=1, bias=False),
+ SynchronizedBatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(512, num_class, kernel_size=1)
+ )
+ self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
+ self.dropout_deepsup = nn.Dropout2d(0.1)
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale in self.ppm:
+ ppm_out.append(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False))
+ ppm_out = torch.cat(ppm_out, 1)
+
+ x = self.conv_last(ppm_out)
+
+ if self.inference or self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ if self.use_softmax:
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ # deep sup
+ conv4 = conv_out[-2]
+ _ = self.cbr_deepsup(conv4)
+ _ = self.dropout_deepsup(_)
+ _ = self.conv_last_deepsup(_)
+
+ x = nn.functional.log_softmax(x, dim=1)
+ _ = nn.functional.log_softmax(_, dim=1)
+
+ return (x, _)
+
+
+# upernet
+class UPerNet(nn.Module):
+ def __init__(self, num_class=150, fc_dim=4096,
+ inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6),
+ fpn_inplanes=(256,512,1024,2048), fpn_dim=256):
+ super(UPerNet, self).__init__()
+ self.use_softmax = use_softmax
+ self.inference = inference
+
+ # PPM Module
+ self.ppm_pooling = []
+ self.ppm_conv = []
+
+ for scale in pool_scales:
+ self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
+ self.ppm_conv.append(nn.Sequential(
+ nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
+ SynchronizedBatchNorm2d(512),
+ nn.ReLU(inplace=True)
+ ))
+ self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
+ self.ppm_conv = nn.ModuleList(self.ppm_conv)
+ self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)
+
+ # FPN Module
+ self.fpn_in = []
+ for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer
+ self.fpn_in.append(nn.Sequential(
+ nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
+ SynchronizedBatchNorm2d(fpn_dim),
+ nn.ReLU(inplace=True)
+ ))
+ self.fpn_in = nn.ModuleList(self.fpn_in)
+
+ self.fpn_out = []
+ for i in range(len(fpn_inplanes) - 1): # skip the top layer
+ self.fpn_out.append(nn.Sequential(
+ conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
+ ))
+ self.fpn_out = nn.ModuleList(self.fpn_out)
+
+ self.conv_last = nn.Sequential(
+ conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
+ nn.Conv2d(fpn_dim, num_class, kernel_size=1)
+ )
+
+ def forward(self, conv_out, segSize=None):
+ conv5 = conv_out[-1]
+
+ input_size = conv5.size()
+ ppm_out = [conv5]
+ for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
+            ppm_out.append(pool_conv(nn.functional.interpolate(
+ pool_scale(conv5),
+ (input_size[2], input_size[3]),
+ mode='bilinear', align_corners=False)))
+ ppm_out = torch.cat(ppm_out, 1)
+ f = self.ppm_last_conv(ppm_out)
+
+ fpn_feature_list = [f]
+ for i in reversed(range(len(conv_out) - 1)):
+ conv_x = conv_out[i]
+ conv_x = self.fpn_in[i](conv_x) # lateral branch
+
+ f = nn.functional.interpolate(
+ f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch
+ f = conv_x + f
+
+ fpn_feature_list.append(self.fpn_out[i](f))
+
+ fpn_feature_list.reverse() # [P2 - P5]
+ output_size = fpn_feature_list[0].size()[2:]
+ fusion_list = [fpn_feature_list[0]]
+ for i in range(1, len(fpn_feature_list)):
+ fusion_list.append(nn.functional.interpolate(
+ fpn_feature_list[i],
+ output_size,
+ mode='bilinear', align_corners=False))
+ fusion_out = torch.cat(fusion_list, 1)
+ x = self.conv_last(fusion_out)
+
+ if self.inference or self.use_softmax: # is True during inference
+ x = nn.functional.interpolate(
+ x, size=segSize, mode='bilinear', align_corners=False)
+ if self.use_softmax:
+ x = nn.functional.softmax(x, dim=1)
+ return x
+
+ x = nn.functional.log_softmax(x, dim=1)
+
+ return x
diff --git a/netdissect/segmodel/object150_info.csv b/netdissect/segmodel/object150_info.csv
new file mode 100644
index 0000000000000000000000000000000000000000..8b34d8f3874a38b96894863c5458a7c3c2b0e2e6
--- /dev/null
+++ b/netdissect/segmodel/object150_info.csv
@@ -0,0 +1,151 @@
+Idx,Ratio,Train,Val,Stuff,Name
+1,0.1576,11664,1172,1,wall
+2,0.1072,6046,612,1,building;edifice
+3,0.0878,8265,796,1,sky
+4,0.0621,9336,917,1,floor;flooring
+5,0.0480,6678,641,0,tree
+6,0.0450,6604,643,1,ceiling
+7,0.0398,4023,408,1,road;route
+8,0.0231,1906,199,0,bed
+9,0.0198,4688,460,0,windowpane;window
+10,0.0183,2423,225,1,grass
+11,0.0181,2874,294,0,cabinet
+12,0.0166,3068,310,1,sidewalk;pavement
+13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
+14,0.0151,1804,190,1,earth;ground
+15,0.0118,6666,796,0,door;double;door
+16,0.0110,4269,411,0,table
+17,0.0109,1691,160,1,mountain;mount
+18,0.0104,3999,441,0,plant;flora;plant;life
+19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
+20,0.0103,3261,318,0,chair
+21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
+22,0.0074,709,75,1,water
+23,0.0067,3296,315,0,painting;picture
+24,0.0065,1191,106,0,sofa;couch;lounge
+25,0.0061,1516,162,0,shelf
+26,0.0060,667,69,1,house
+27,0.0053,651,57,1,sea
+28,0.0052,1847,224,0,mirror
+29,0.0046,1158,128,1,rug;carpet;carpeting
+30,0.0044,480,44,1,field
+31,0.0044,1172,98,0,armchair
+32,0.0044,1292,184,0,seat
+33,0.0033,1386,138,0,fence;fencing
+34,0.0031,698,61,0,desk
+35,0.0030,781,73,0,rock;stone
+36,0.0027,380,43,0,wardrobe;closet;press
+37,0.0026,3089,302,0,lamp
+38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
+39,0.0024,804,99,0,railing;rail
+40,0.0023,1453,153,0,cushion
+41,0.0023,411,37,0,base;pedestal;stand
+42,0.0022,1440,162,0,box
+43,0.0022,800,77,0,column;pillar
+44,0.0020,2650,298,0,signboard;sign
+45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
+46,0.0019,367,36,0,counter
+47,0.0018,311,30,1,sand
+48,0.0018,1181,122,0,sink
+49,0.0018,287,23,1,skyscraper
+50,0.0018,468,38,0,fireplace;hearth;open;fireplace
+51,0.0018,402,43,0,refrigerator;icebox
+52,0.0018,130,12,1,grandstand;covered;stand
+53,0.0018,561,64,1,path
+54,0.0017,880,102,0,stairs;steps
+55,0.0017,86,12,1,runway
+56,0.0017,172,11,0,case;display;case;showcase;vitrine
+57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
+58,0.0017,930,109,0,pillow
+59,0.0015,139,18,0,screen;door;screen
+60,0.0015,564,52,1,stairway;staircase
+61,0.0015,320,26,1,river
+62,0.0015,261,29,1,bridge;span
+63,0.0014,275,22,0,bookcase
+64,0.0014,335,60,0,blind;screen
+65,0.0014,792,75,0,coffee;table;cocktail;table
+66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
+67,0.0014,1309,138,0,flower
+68,0.0013,1112,113,0,book
+69,0.0013,266,27,1,hill
+70,0.0013,659,66,0,bench
+71,0.0012,331,31,0,countertop
+72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
+73,0.0012,369,36,0,palm;palm;tree
+74,0.0012,144,9,0,kitchen;island
+75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
+76,0.0010,324,33,0,swivel;chair
+77,0.0009,304,27,0,boat
+78,0.0009,170,20,0,bar
+79,0.0009,68,6,0,arcade;machine
+80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
+81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
+82,0.0008,492,49,0,towel
+83,0.0008,2510,269,0,light;light;source
+84,0.0008,440,39,0,truck;motortruck
+85,0.0008,147,18,1,tower
+86,0.0008,583,56,0,chandelier;pendant;pendent
+87,0.0007,533,61,0,awning;sunshade;sunblind
+88,0.0007,1989,239,0,streetlight;street;lamp
+89,0.0007,71,5,0,booth;cubicle;stall;kiosk
+90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
+91,0.0007,135,12,0,airplane;aeroplane;plane
+92,0.0007,83,5,1,dirt;track
+93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
+94,0.0006,1003,104,0,pole
+95,0.0006,182,12,1,land;ground;soil
+96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
+97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
+98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
+99,0.0006,965,114,0,bottle
+100,0.0006,117,13,0,buffet;counter;sideboard
+101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
+102,0.0006,108,9,1,stage
+103,0.0006,557,55,0,van
+104,0.0006,52,4,0,ship
+105,0.0005,99,5,0,fountain
+106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
+107,0.0005,292,31,0,canopy
+108,0.0005,77,9,0,washer;automatic;washer;washing;machine
+109,0.0005,340,38,0,plaything;toy
+110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
+111,0.0005,465,49,0,stool
+112,0.0005,50,4,0,barrel;cask
+113,0.0005,622,75,0,basket;handbasket
+114,0.0005,80,9,1,waterfall;falls
+115,0.0005,59,3,0,tent;collapsible;shelter
+116,0.0005,531,72,0,bag
+117,0.0005,282,30,0,minibike;motorbike
+118,0.0005,73,7,0,cradle
+119,0.0005,435,44,0,oven
+120,0.0005,136,25,0,ball
+121,0.0005,116,24,0,food;solid;food
+122,0.0004,266,31,0,step;stair
+123,0.0004,58,12,0,tank;storage;tank
+124,0.0004,418,83,0,trade;name;brand;name;brand;marque
+125,0.0004,319,43,0,microwave;microwave;oven
+126,0.0004,1193,139,0,pot;flowerpot
+127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
+128,0.0004,347,36,0,bicycle;bike;wheel;cycle
+129,0.0004,52,5,1,lake
+130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
+131,0.0004,108,13,0,screen;silver;screen;projection;screen
+132,0.0004,201,30,0,blanket;cover
+133,0.0004,285,21,0,sculpture
+134,0.0004,268,27,0,hood;exhaust;hood
+135,0.0003,1020,108,0,sconce
+136,0.0003,1282,122,0,vase
+137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
+138,0.0003,453,57,0,tray
+139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
+140,0.0003,397,44,0,fan
+141,0.0003,92,8,1,pier;wharf;wharfage;dock
+142,0.0003,228,18,0,crt;screen
+143,0.0003,570,59,0,plate
+144,0.0003,217,22,0,monitor;monitoring;device
+145,0.0003,206,19,0,bulletin;board;notice;board
+146,0.0003,130,14,0,shower
+147,0.0003,178,28,0,radiator
+148,0.0002,504,57,0,glass;drinking;glass
+149,0.0002,775,96,0,clock
+150,0.0002,421,56,0,flag
diff --git a/netdissect/segmodel/resnet.py b/netdissect/segmodel/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea5fdf82fafa3058c5f00074d55fbb1e584d5865
--- /dev/null
+++ b/netdissect/segmodel/resnet.py
@@ -0,0 +1,235 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import math
+try:
+ from lib.nn import SynchronizedBatchNorm2d
+except ImportError:
+ from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
+
+try:
+ from urllib import urlretrieve
+except ImportError:
+ from urllib.request import urlretrieve
+
+
+__all__ = ['ResNet', 'resnet50', 'resnet101']
+
+
+model_urls = {
+ 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
+ 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = SynchronizedBatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = SynchronizedBatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = SynchronizedBatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = SynchronizedBatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = SynchronizedBatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000):
+ self.inplanes = 128
+ super(ResNet, self).__init__()
+ self.conv1 = conv3x3(3, 64, stride=2)
+ self.bn1 = SynchronizedBatchNorm2d(64)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(64, 64)
+ self.bn2 = SynchronizedBatchNorm2d(64)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = conv3x3(64, 128)
+ self.bn3 = SynchronizedBatchNorm2d(128)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, SynchronizedBatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ SynchronizedBatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+'''
+def resnet18(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet18']))
+ return model
+
+
+def resnet34(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet34']))
+ return model
+'''
+
+def resnet50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
+ return model
+
+
+def resnet101(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
+ return model
+
+# def resnet152(pretrained=False, **kwargs):
+# """Constructs a ResNet-152 model.
+#
+# Args:
+# pretrained (bool): If True, returns a model pre-trained on Places
+# """
+# model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+# if pretrained:
+# model.load_state_dict(load_url(model_urls['resnet152']))
+# return model
+
+def load_url(url, model_dir='./pretrained', map_location=None):
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+ filename = url.split('/')[-1]
+ cached_file = os.path.join(model_dir, filename)
+ if not os.path.exists(cached_file):
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+ urlretrieve(url, cached_file)
+ return torch.load(cached_file, map_location=map_location)
diff --git a/netdissect/segmodel/resnext.py b/netdissect/segmodel/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbb7461a6c8eb126717967cdca5d5ce392aecea
--- /dev/null
+++ b/netdissect/segmodel/resnext.py
@@ -0,0 +1,182 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import math
+try:
+ from lib.nn import SynchronizedBatchNorm2d
+except ImportError:
+ from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
+try:
+ from urllib import urlretrieve
+except ImportError:
+ from urllib.request import urlretrieve
+
+
+__all__ = ['ResNeXt', 'resnext101'] # support resnext 101
+
+
+model_urls = {
+ #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
+ 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class GroupBottleneck(nn.Module):
+ expansion = 2
+
+ def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
+ super(GroupBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = SynchronizedBatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, groups=groups, bias=False)
+ self.bn2 = SynchronizedBatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
+ self.bn3 = SynchronizedBatchNorm2d(planes * 2)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNeXt(nn.Module):
+
+ def __init__(self, block, layers, groups=32, num_classes=1000):
+ self.inplanes = 128
+ super(ResNeXt, self).__init__()
+ self.conv1 = conv3x3(3, 64, stride=2)
+ self.bn1 = SynchronizedBatchNorm2d(64)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(64, 64)
+ self.bn2 = SynchronizedBatchNorm2d(64)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = conv3x3(64, 128)
+ self.bn3 = SynchronizedBatchNorm2d(128)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
+ self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
+ self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
+ self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(1024 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, SynchronizedBatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1, groups=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ SynchronizedBatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, groups, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=groups))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+'''
+def resnext50(pretrained=False, **kwargs):
+    """Constructs a ResNeXt-50 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
+ return model
+'''
+
+
+def resnext101(pretrained=False, **kwargs):
+    """Constructs a ResNeXt-101 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
+ return model
+
+
+# def resnext152(pretrained=False, **kwargs):
+# """Constructs a ResNeXt-152 model.
+#
+# Args:
+# pretrained (bool): If True, returns a model pre-trained on Places
+# """
+# model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
+# if pretrained:
+# model.load_state_dict(load_url(model_urls['resnext152']))
+# return model
+
+
+def load_url(url, model_dir='./pretrained', map_location=None):
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+ filename = url.split('/')[-1]
+ cached_file = os.path.join(model_dir, filename)
+ if not os.path.exists(cached_file):
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+ urlretrieve(url, cached_file)
+ return torch.load(cached_file, map_location=map_location)
diff --git a/netdissect/segviz.py b/netdissect/segviz.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb954317aaf0fd6e31b6216cc7a59f01a5fb0bd
--- /dev/null
+++ b/netdissect/segviz.py
@@ -0,0 +1,283 @@
+import numpy
+import scipy.misc  # scipy.misc.imresize below needs this submodule import (older scipy releases only)
+
+def segment_visualization(seg, size):
+ result = numpy.zeros((seg.shape[1] * seg.shape[2], 3), dtype=numpy.uint8)
+ flatseg = seg.reshape(seg.shape[0], seg.shape[1] * seg.shape[2])
+ bc = numpy.bincount(flatseg.flatten())
+ top = numpy.argsort(-bc)
+ # In a multilabel segmentation, we can't draw everything.
+ # Draw the fewest-pixel labels last. (We could pick the opposite order.)
+ for label in top:
+ if label == 0:
+ continue
+ if bc[label] == 0:
+ break
+ bitmap = ((flatseg == label).sum(axis=0) > 0)
+ result[bitmap] = high_contrast_arr[label % len(high_contrast_arr)]
+ result = result.reshape((seg.shape[1], seg.shape[2], 3))
+ if seg.shape[1:] != size:
+ result = scipy.misc.imresize(result, size, interp='nearest')
+ return result
+
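+# Illustrative usage (hypothetical data, not from this repo): for a multilabel
+# segmentation array seg of shape (num_labels, H, W) holding integer class ids,
+#   rgb = segment_visualization(seg, (H, W))
+# returns an (H, W, 3) uint8 image in which every nonzero label is painted with
+# a color from the high-contrast palette below, rarest labels drawn last.
+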
+# A palette that maximizes perceptual contrast between entries.
+# https://stackoverflow.com/questions/33295120
+high_contrast = [
+ [0, 0, 0], [255, 255, 0], [28, 230, 255], [255, 52, 255],
+ [255, 74, 70], [0, 137, 65], [0, 111, 166], [163, 0, 89],
+ [255, 219, 229], [122, 73, 0], [0, 0, 166], [99, 255, 172],
+ [183, 151, 98], [0, 77, 67], [143, 176, 255], [153, 125, 135],
+ [90, 0, 7], [128, 150, 147], [254, 255, 230], [27, 68, 0],
+ [79, 198, 1], [59, 93, 255], [74, 59, 83], [255, 47, 128],
+ [97, 97, 90], [186, 9, 0], [107, 121, 0], [0, 194, 160],
+ [255, 170, 146], [255, 144, 201], [185, 3, 170], [209, 97, 0],
+ [221, 239, 255], [0, 0, 53], [123, 79, 75], [161, 194, 153],
+ [48, 0, 24], [10, 166, 216], [1, 51, 73], [0, 132, 111],
+ [55, 33, 1], [255, 181, 0], [194, 255, 237], [160, 121, 191],
+ [204, 7, 68], [192, 185, 178], [194, 255, 153], [0, 30, 9],
+ [0, 72, 156], [111, 0, 98], [12, 189, 102], [238, 195, 255],
+ [69, 109, 117], [183, 123, 104], [122, 135, 161], [120, 141, 102],
+ [136, 85, 120], [250, 208, 159], [255, 138, 154], [209, 87, 160],
+ [190, 196, 89], [69, 102, 72], [0, 134, 237], [136, 111, 76],
+ [52, 54, 45], [180, 168, 189], [0, 166, 170], [69, 44, 44],
+ [99, 99, 117], [163, 200, 201], [255, 145, 63], [147, 138, 129],
+ [87, 83, 41], [0, 254, 207], [176, 91, 111], [140, 208, 255],
+ [59, 151, 0], [4, 247, 87], [200, 161, 161], [30, 110, 0],
+ [121, 0, 215], [167, 117, 0], [99, 103, 169], [160, 88, 55],
+ [107, 0, 44], [119, 38, 0], [215, 144, 255], [155, 151, 0],
+ [84, 158, 121], [255, 246, 159], [32, 22, 37], [114, 65, 143],
+ [188, 35, 255], [153, 173, 192], [58, 36, 101], [146, 35, 41],
+ [91, 69, 52], [253, 232, 220], [64, 78, 85], [0, 137, 163],
+ [203, 126, 152], [164, 232, 4], [50, 78, 114], [106, 58, 76],
+ [131, 171, 88], [0, 28, 30], [209, 247, 206], [0, 75, 40],
+ [200, 208, 246], [163, 164, 137], [128, 108, 102], [34, 40, 0],
+ [191, 86, 80], [232, 48, 0], [102, 121, 109], [218, 0, 124],
+ [255, 26, 89], [138, 219, 180], [30, 2, 0], [91, 78, 81],
+ [200, 149, 197], [50, 0, 51], [255, 104, 50], [102, 225, 211],
+ [207, 205, 172], [208, 172, 148], [126, 211, 121], [1, 44, 88],
+ [122, 123, 255], [214, 142, 1], [53, 51, 57], [120, 175, 161],
+ [254, 178, 198], [117, 121, 124], [131, 115, 147], [148, 58, 77],
+ [181, 244, 255], [210, 220, 213], [149, 86, 189], [106, 113, 74],
+ [0, 19, 37], [2, 82, 95], [10, 163, 247], [233, 129, 118],
+ [219, 213, 221], [94, 188, 209], [61, 79, 68], [126, 100, 5],
+ [2, 104, 78], [150, 43, 117], [141, 133, 70], [150, 149, 197],
+ [231, 115, 206], [216, 106, 120], [62, 137, 190], [202, 131, 78],
+ [81, 138, 135], [91, 17, 60], [85, 129, 59], [231, 4, 196],
+ [0, 0, 95], [169, 115, 153], [75, 129, 96], [89, 115, 138],
+ [255, 93, 167], [247, 201, 191], [100, 49, 39], [81, 58, 1],
+ [107, 148, 170], [81, 160, 88], [164, 91, 2], [29, 23, 2],
+ [226, 0, 39], [231, 171, 99], [76, 96, 1], [156, 105, 102],
+ [100, 84, 123], [151, 151, 158], [0, 106, 102], [57, 20, 6],
+ [244, 215, 73], [0, 69, 210], [0, 108, 49], [221, 182, 208],
+ [124, 101, 113], [159, 178, 164], [0, 216, 145], [21, 160, 138],
+ [188, 101, 233], [255, 255, 254], [198, 220, 153], [32, 59, 60],
+ [103, 17, 144], [107, 58, 100], [245, 225, 255], [255, 160, 242],
+ [204, 170, 53], [55, 69, 39], [139, 180, 0], [121, 120, 104],
+ [198, 0, 90], [59, 0, 10], [200, 98, 64], [41, 96, 124],
+ [64, 35, 52], [125, 90, 68], [204, 184, 124], [184, 129, 131],
+ [170, 81, 153], [181, 214, 195], [163, 132, 105], [159, 148, 240],
+ [167, 69, 113], [184, 148, 166], [113, 187, 140], [0, 180, 51],
+ [120, 158, 201], [109, 128, 186], [149, 63, 0], [94, 255, 3],
+ [228, 255, 252], [27, 225, 119], [188, 177, 229], [118, 145, 47],
+ [0, 49, 9], [0, 96, 205], [210, 0, 150], [137, 85, 99],
+ [41, 32, 29], [91, 50, 19], [167, 111, 66], [137, 65, 46],
+ [26, 58, 42], [73, 75, 90], [168, 140, 133], [244, 171, 170],
+ [163, 243, 171], [0, 198, 200], [234, 139, 102], [149, 138, 159],
+ [189, 201, 210], [159, 160, 100], [190, 71, 0], [101, 129, 136],
+ [131, 164, 133], [69, 60, 35], [71, 103, 93], [58, 63, 0],
+ [6, 18, 3], [223, 251, 113], [134, 142, 126], [152, 208, 88],
+ [108, 143, 125], [215, 191, 194], [60, 62, 110], [216, 61, 102],
+ [47, 93, 155], [108, 94, 70], [210, 91, 136], [91, 101, 108],
+ [0, 181, 127], [84, 92, 70], [134, 96, 151], [54, 93, 37],
+ [37, 47, 153], [0, 204, 255], [103, 78, 96], [252, 0, 156],
+ [146, 137, 107], [30, 35, 36], [222, 201, 178], [157, 73, 72],
+ [133, 171, 180], [52, 33, 66], [208, 150, 133], [164, 172, 172],
+ [0, 255, 255], [174, 156, 134], [116, 42, 51], [14, 114, 197],
+ [175, 216, 236], [192, 100, 185], [145, 2, 140], [254, 237, 191],
+ [255, 183, 137], [156, 184, 228], [175, 255, 209], [42, 54, 76],
+ [79, 74, 67], [100, 112, 149], [52, 187, 255], [128, 119, 129],
+ [146, 0, 3], [179, 165, 167], [1, 134, 21], [241, 255, 200],
+ [151, 111, 92], [255, 59, 193], [255, 95, 107], [7, 125, 132],
+ [245, 109, 147], [87, 113, 218], [78, 30, 42], [131, 0, 85],
+ [2, 211, 70], [190, 69, 45], [0, 144, 94], [190, 0, 40],
+ [110, 150, 227], [0, 118, 153], [254, 201, 109], [156, 106, 125],
+ [63, 161, 184], [137, 61, 227], [121, 180, 214], [127, 212, 217],
+ [103, 81, 187], [178, 141, 45], [226, 122, 5], [221, 156, 184],
+ [170, 188, 122], [152, 0, 52], [86, 26, 2], [143, 127, 0],
+ [99, 80, 0], [205, 125, 174], [138, 94, 45], [255, 179, 225],
+ [107, 100, 102], [198, 211, 0], [1, 0, 226], [136, 236, 105],
+ [143, 204, 190], [33, 0, 28], [81, 31, 77], [227, 246, 227],
+ [255, 142, 177], [107, 79, 41], [163, 127, 70], [106, 89, 80],
+ [31, 42, 26], [4, 120, 77], [16, 24, 53], [230, 224, 208],
+ [255, 116, 254], [0, 164, 95], [143, 93, 248], [75, 0, 89],
+ [65, 47, 35], [216, 147, 158], [219, 157, 114], [96, 65, 67],
+ [181, 186, 206], [152, 158, 183], [210, 196, 219], [165, 135, 175],
+ [119, 215, 150], [127, 140, 148], [255, 155, 3], [85, 81, 150],
+ [49, 221, 174], [116, 182, 113], [128, 38, 71], [42, 55, 63],
+ [1, 74, 104], [105, 102, 40], [76, 123, 109], [0, 44, 39],
+ [122, 69, 34], [59, 88, 89], [229, 211, 129], [255, 243, 255],
+ [103, 159, 160], [38, 19, 0], [44, 87, 66], [145, 49, 175],
+ [175, 93, 136], [199, 112, 106], [97, 171, 31], [140, 242, 212],
+ [197, 217, 184], [159, 255, 251], [191, 69, 204], [73, 57, 65],
+ [134, 59, 96], [185, 0, 118], [0, 49, 119], [197, 130, 210],
+ [193, 179, 148], [96, 43, 112], [136, 120, 104], [186, 191, 176],
+ [3, 0, 18], [209, 172, 254], [127, 222, 254], [75, 92, 113],
+ [163, 160, 151], [230, 109, 83], [99, 123, 93], [146, 190, 165],
+ [0, 248, 179], [190, 221, 255], [61, 181, 167], [221, 50, 72],
+ [182, 228, 222], [66, 119, 69], [89, 140, 90], [185, 76, 89],
+ [129, 129, 213], [148, 136, 139], [254, 214, 189], [83, 109, 49],
+ [110, 255, 146], [228, 232, 255], [32, 226, 0], [255, 208, 242],
+ [76, 131, 161], [189, 115, 34], [145, 92, 78], [140, 71, 135],
+ [2, 81, 23], [162, 170, 69], [45, 27, 33], [169, 221, 176],
+ [255, 79, 120], [82, 133, 0], [0, 154, 46], [23, 252, 228],
+ [113, 85, 90], [82, 93, 130], [0, 25, 90], [150, 120, 116],
+ [85, 85, 88], [11, 33, 44], [30, 32, 43], [239, 191, 196],
+ [111, 151, 85], [111, 117, 134], [80, 29, 29], [55, 45, 0],
+ [116, 29, 22], [94, 179, 147], [181, 180, 0], [221, 74, 56],
+ [54, 61, 255], [173, 101, 82], [102, 53, 175], [131, 107, 186],
+ [152, 170, 127], [70, 72, 54], [50, 44, 62], [124, 185, 186],
+ [91, 105, 101], [112, 125, 61], [122, 0, 29], [110, 70, 54],
+ [68, 58, 56], [174, 129, 255], [72, 144, 121], [137, 115, 52],
+ [0, 144, 135], [218, 113, 60], [54, 22, 24], [255, 111, 1],
+ [0, 102, 121], [55, 14, 119], [75, 58, 131], [201, 226, 230],
+ [196, 65, 112], [255, 69, 38], [115, 190, 84], [196, 223, 114],
+ [173, 255, 96], [0, 68, 125], [220, 206, 201], [189, 148, 121],
+ [101, 110, 91], [236, 82, 0], [255, 110, 194], [122, 97, 126],
+ [221, 174, 162], [119, 131, 127], [165, 51, 39], [96, 142, 255],
+ [181, 153, 215], [165, 1, 73], [78, 0, 37], [201, 177, 169],
+ [3, 145, 154], [27, 42, 37], [229, 0, 241], [152, 46, 11],
+ [182, 113, 128], [224, 88, 89], [0, 96, 57], [87, 143, 155],
+ [48, 82, 48], [206, 147, 76], [179, 194, 190], [192, 186, 192],
+ [181, 6, 211], [23, 12, 16], [76, 83, 79], [34, 68, 81],
+ [62, 65, 65], [120, 114, 109], [182, 96, 43], [32, 4, 65],
+ [221, 181, 136], [73, 114, 0], [197, 170, 182], [3, 60, 97],
+ [113, 178, 245], [169, 224, 136], [73, 121, 176], [162, 195, 223],
+ [120, 65, 73], [45, 43, 23], [62, 14, 47], [87, 52, 76],
+ [0, 145, 190], [228, 81, 209], [75, 75, 106], [92, 1, 26],
+ [124, 128, 96], [255, 148, 145], [76, 50, 93], [0, 92, 139],
+ [229, 253, 164], [104, 209, 182], [3, 38, 65], [20, 0, 35],
+ [134, 131, 169], [207, 255, 0], [167, 44, 62], [52, 71, 90],
+ [177, 187, 154], [180, 160, 79], [141, 145, 142], [161, 104, 166],
+ [129, 61, 58], [66, 82, 24], [218, 131, 134], [119, 97, 51],
+ [86, 57, 48], [132, 152, 174], [144, 193, 211], [181, 102, 107],
+ [155, 88, 94], [133, 100, 101], [173, 124, 144], [226, 188, 0],
+ [227, 170, 224], [178, 194, 254], [253, 0, 57], [0, 155, 117],
+ [255, 244, 109], [232, 126, 172], [223, 227, 230], [132, 133, 144],
+ [170, 146, 151], [131, 161, 147], [87, 121, 119], [62, 113, 88],
+ [198, 66, 137], [234, 0, 114], [196, 168, 203], [85, 200, 153],
+ [231, 143, 207], [0, 69, 71], [246, 226, 227], [150, 103, 22],
+ [55, 143, 219], [67, 94, 106], [218, 0, 4], [27, 0, 15],
+ [91, 156, 143], [110, 43, 82], [1, 17, 21], [227, 232, 196],
+ [174, 59, 133], [234, 28, 169], [255, 158, 107], [69, 125, 139],
+ [146, 103, 139], [0, 205, 187], [156, 204, 4], [0, 46, 56],
+ [150, 197, 127], [207, 246, 180], [73, 40, 24], [118, 110, 82],
+ [32, 55, 14], [227, 209, 159], [46, 60, 48], [178, 234, 206],
+ [243, 189, 164], [162, 78, 61], [151, 111, 217], [140, 159, 168],
+ [124, 43, 115], [78, 95, 55], [93, 84, 98], [144, 149, 111],
+ [106, 167, 118], [219, 203, 246], [218, 113, 255], [152, 124, 149],
+ [82, 50, 60], [187, 60, 66], [88, 77, 57], [79, 193, 95],
+ [162, 185, 193], [121, 219, 33], [29, 89, 88], [189, 116, 78],
+ [22, 11, 0], [32, 34, 26], [107, 130, 149], [0, 224, 228],
+ [16, 36, 1], [27, 120, 42], [218, 169, 181], [176, 65, 93],
+ [133, 146, 83], [151, 160, 148], [6, 227, 196], [71, 104, 140],
+ [124, 103, 85], [7, 92, 0], [117, 96, 213], [125, 159, 0],
+ [195, 109, 150], [77, 145, 62], [95, 66, 118], [252, 228, 200],
+ [48, 48, 82], [79, 56, 27], [229, 165, 50], [112, 102, 144],
+ [170, 154, 146], [35, 115, 99], [115, 1, 62], [255, 144, 121],
+ [167, 154, 116], [2, 155, 219], [255, 1, 105], [199, 210, 231],
+ [202, 136, 105], [128, 255, 205], [187, 31, 105], [144, 176, 171],
+ [125, 116, 169], [252, 199, 219], [153, 55, 91], [0, 171, 77],
+ [171, 174, 209], [190, 157, 145], [230, 229, 167], [51, 44, 34],
+ [221, 88, 123], [245, 255, 247], [93, 48, 51], [109, 56, 0],
+ [255, 0, 32], [181, 123, 179], [215, 255, 230], [197, 53, 169],
+ [38, 0, 9], [106, 135, 129], [168, 171, 180], [212, 82, 98],
+ [121, 75, 97], [70, 33, 178], [141, 164, 219], [199, 200, 144],
+ [111, 233, 173], [162, 67, 167], [178, 176, 129], [24, 27, 0],
+ [40, 97, 84], [76, 164, 59], [106, 149, 115], [168, 68, 29],
+ [92, 114, 123], [115, 134, 113], [208, 207, 203], [137, 123, 119],
+ [31, 63, 34], [65, 69, 167], [218, 152, 148], [161, 117, 122],
+ [99, 36, 60], [173, 170, 255], [0, 205, 226], [221, 188, 98],
+ [105, 142, 177], [32, 132, 98], [0, 183, 224], [97, 74, 68],
+ [155, 187, 87], [122, 92, 84], [133, 122, 80], [118, 107, 126],
+ [1, 72, 51], [255, 131, 71], [122, 142, 186], [39, 71, 64],
+ [148, 100, 68], [235, 216, 230], [100, 98, 65], [55, 57, 23],
+ [106, 212, 80], [129, 129, 123], [212, 153, 227], [151, 148, 64],
+ [1, 26, 18], [82, 101, 84], [181, 136, 92], [164, 153, 165],
+ [3, 173, 137], [179, 0, 139], [227, 196, 181], [150, 83, 31],
+ [134, 113, 117], [116, 86, 158], [97, 125, 159], [231, 4, 82],
+ [6, 126, 175], [166, 151, 182], [183, 135, 168], [156, 255, 147],
+ [49, 29, 25], [58, 148, 89], [110, 116, 110], [176, 197, 174],
+ [132, 237, 247], [237, 52, 136], [117, 76, 120], [56, 70, 68],
+ [199, 132, 123], [0, 182, 197], [127, 166, 112], [193, 175, 158],
+ [42, 127, 255], [114, 165, 140], [255, 192, 127], [157, 235, 221],
+ [217, 124, 142], [126, 124, 147], [98, 230, 116], [181, 99, 158],
+ [255, 168, 97], [194, 165, 128], [141, 156, 131], [183, 5, 70],
+ [55, 43, 46], [0, 152, 255], [152, 89, 117], [32, 32, 76],
+ [255, 108, 96], [68, 80, 131], [133, 2, 170], [114, 54, 31],
+ [150, 118, 163], [72, 68, 73], [206, 214, 194], [59, 22, 74],
+ [204, 167, 99], [44, 127, 119], [2, 34, 123], [163, 126, 111],
+ [205, 230, 220], [205, 255, 251], [190, 129, 26], [247, 113, 131],
+ [237, 230, 226], [205, 198, 180], [255, 224, 158], [58, 114, 113],
+ [255, 123, 89], [78, 78, 1], [74, 198, 132], [139, 200, 145],
+ [188, 138, 150], [207, 99, 83], [220, 222, 92], [94, 170, 221],
+ [246, 160, 173], [226, 105, 170], [163, 218, 228], [67, 110, 131],
+ [0, 46, 23], [236, 251, 255], [161, 194, 182], [80, 0, 63],
+ [113, 105, 91], [103, 196, 187], [83, 110, 255], [93, 90, 72],
+ [137, 0, 57], [150, 147, 129], [55, 21, 33], [94, 70, 101],
+ [170, 98, 195], [141, 111, 129], [44, 97, 53], [65, 6, 1],
+ [86, 70, 32], [230, 144, 52], [109, 166, 189], [229, 142, 86],
+ [227, 166, 139], [72, 177, 118], [210, 125, 103], [181, 178, 104],
+ [127, 132, 39], [255, 132, 230], [67, 87, 64], [234, 228, 8],
+ [244, 245, 255], [50, 88, 0], [75, 107, 165], [173, 206, 255],
+ [155, 138, 204], [136, 81, 56], [88, 117, 193], [126, 115, 17],
+ [254, 165, 202], [159, 139, 91], [165, 91, 84], [137, 0, 106],
+ [175, 117, 111], [42, 32, 0], [116, 153, 161], [255, 181, 80],
+ [0, 1, 30], [209, 81, 28], [104, 129, 81], [188, 144, 138],
+ [120, 200, 235], [133, 2, 255], [72, 61, 48], [196, 34, 33],
+ [94, 167, 255], [120, 87, 21], [12, 234, 145], [255, 250, 237],
+ [179, 175, 157], [62, 61, 82], [90, 155, 194], [156, 47, 144],
+ [141, 87, 0], [173, 215, 156], [0, 118, 139], [51, 125, 0],
+ [197, 151, 0], [49, 86, 220], [148, 69, 117], [236, 255, 220],
+ [210, 76, 178], [151, 112, 60], [76, 37, 127], [158, 3, 102],
+ [136, 255, 236], [181, 100, 129], [57, 109, 43], [86, 115, 95],
+ [152, 131, 118], [155, 177, 149], [169, 121, 92], [228, 197, 211],
+ [159, 79, 103], [30, 43, 57], [102, 67, 39], [175, 206, 120],
+ [50, 46, 223], [134, 180, 135], [194, 48, 0], [171, 232, 107],
+ [150, 101, 109], [37, 14, 53], [166, 0, 25], [0, 128, 207],
+ [202, 239, 255], [50, 63, 97], [164, 73, 220], [106, 157, 59],
+ [255, 90, 228], [99, 106, 1], [209, 108, 218], [115, 96, 96],
+ [255, 186, 173], [211, 105, 180], [255, 222, 214], [108, 109, 116],
+ [146, 125, 94], [132, 93, 112], [91, 98, 193], [47, 74, 54],
+ [228, 95, 53], [255, 59, 83], [172, 132, 221], [118, 41, 136],
+ [112, 236, 152], [64, 133, 67], [44, 53, 51], [46, 24, 45],
+ [50, 57, 37], [25, 24, 27], [47, 46, 44], [2, 60, 50],
+ [155, 158, 226], [88, 175, 173], [92, 66, 77], [122, 197, 166],
+ [104, 93, 117], [185, 188, 189], [131, 67, 87], [26, 123, 66],
+ [46, 87, 170], [229, 81, 153], [49, 110, 71], [205, 0, 197],
+ [106, 0, 77], [127, 187, 236], [243, 86, 145], [215, 197, 74],
+ [98, 172, 183], [203, 161, 188], [162, 138, 154], [108, 63, 59],
+ [255, 228, 125], [220, 186, 227], [95, 129, 109], [58, 64, 74],
+ [125, 191, 50], [230, 236, 220], [133, 44, 25], [40, 83, 102],
+ [184, 203, 156], [14, 13, 0], [75, 93, 86], [107, 84, 63],
+ [226, 113, 114], [5, 104, 236], [46, 181, 0], [210, 22, 86],
+ [239, 175, 255], [104, 32, 33], [45, 32, 17], [218, 76, 255],
+ [112, 150, 142], [255, 123, 125], [74, 25, 48], [232, 194, 130],
+ [231, 219, 188], [166, 132, 134], [31, 38, 60], [54, 87, 78],
+ [82, 206, 121], [173, 170, 169], [138, 159, 69], [101, 66, 210],
+ [0, 251, 140], [93, 105, 123], [204, 210, 127], [148, 165, 161],
+ [121, 2, 41], [227, 131, 230], [126, 164, 193], [78, 68, 82],
+ [75, 44, 0], [98, 11, 112], [49, 76, 30], [135, 74, 166],
+ [227, 0, 145], [102, 70, 10], [235, 154, 139], [234, 195, 163],
+ [152, 234, 179], [171, 145, 128], [184, 85, 47], [26, 43, 47],
+ [148, 221, 197], [157, 140, 118], [156, 131, 51], [148, 169, 201],
+ [57, 41, 53], [140, 103, 94], [204, 233, 58], [145, 113, 0],
+ [1, 64, 11], [68, 152, 150], [28, 163, 112], [224, 141, 167],
+ [139, 74, 78], [102, 119, 118], [70, 146, 173], [103, 189, 168],
+ [105, 37, 92], [211, 191, 255], [74, 81, 50], [126, 146, 133],
+ [119, 115, 60], [231, 160, 204], [81, 162, 136], [44, 101, 106],
+ [77, 92, 94], [201, 64, 58], [221, 215, 243], [0, 88, 68],
+ [180, 162, 0], [72, 143, 105], [133, 129, 130], [212, 233, 185],
+ [61, 115, 151], [202, 232, 206], [214, 0, 52], [170, 103, 70],
+ [158, 85, 133], [186, 98, 0]
+]
+
+high_contrast_arr = numpy.array(high_contrast, dtype=numpy.uint8)
diff --git a/netdissect/server.py b/netdissect/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8422a2bad5ac2a09d4582a98da4f962dac1a911
--- /dev/null
+++ b/netdissect/server.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+
+import argparse, connexion, os, sys, yaml, json, socket
+from netdissect.easydict import EasyDict
+from flask import send_from_directory, redirect
+from flask_cors import CORS
+
+
+from netdissect.serverstate import DissectionProject
+
+__author__ = 'Hendrik Strobelt, David Bau'
+
+CONFIG_FILE_NAME = 'dissect.json'
+projects = {}
+
+app = connexion.App(__name__, debug=False)
+
+
+def get_all_projects():
+ res = []
+ for key, project in projects.items():
+ # print key
+ res.append({
+ 'project': key,
+ 'info': {
+ 'layers': [layer['layer'] for layer in project.get_layers()]
+ }
+ })
+ return sorted(res, key=lambda x: x['project'])
+
+def get_layers(project):
+ return {
+ 'request': {'project': project},
+ 'res': projects[project].get_layers()
+ }
+
+def get_units(project, layer):
+ return {
+ 'request': {'project': project, 'layer': layer},
+ 'res': projects[project].get_units(layer)
+ }
+
+def get_rankings(project, layer):
+ return {
+ 'request': {'project': project, 'layer': layer},
+ 'res': projects[project].get_rankings(layer)
+ }
+
+def get_levels(project, layer, quantiles):
+ return {
+ 'request': {'project': project, 'layer': layer, 'quantiles': quantiles},
+ 'res': projects[project].get_levels(layer, quantiles)
+ }
+
+def get_channels(project, layer):
+ answer = dict(channels=projects[project].get_channels(layer))
+ return {
+ 'request': {'project': project, 'layer': layer},
+ 'res': answer
+ }
+
+def post_generate(gen_req):
+ project = gen_req['project']
+ zs = gen_req.get('zs', None)
+ ids = gen_req.get('ids', None)
+ return_urls = gen_req.get('return_urls', False)
+ assert (zs is None) != (ids is None) # one or the other, not both
+ ablations = gen_req.get('ablations', [])
+ interventions = gen_req.get('interventions', None)
+    # no z available if ablations
+ generated = projects[project].generate_images(zs, ids, interventions,
+ return_urls=return_urls)
+ return {
+ 'request': gen_req,
+ 'res': generated
+ }
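+
+# A sketch of a generate request body (field names as read above; the values
+# are invented for illustration):
+#   {"project": "churchoutdoor", "ids": [0, 1], "return_urls": 1,
+#    "interventions": [{"ablations": [{"layer": "layer4", "unit": 365,
+#                                      "alpha": 1.0, "value": 0.0}]}]}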
+
+def post_features(feat_req):
+ project = feat_req['project']
+ ids = feat_req['ids']
+ masks = feat_req.get('masks', None)
+ layers = feat_req.get('layers', None)
+ interventions = feat_req.get('interventions', None)
+ features = projects[project].get_features(
+ ids, masks, layers, interventions)
+ return {
+ 'request': feat_req,
+ 'res': features
+ }
+
+def post_featuremaps(feat_req):
+ project = feat_req['project']
+ ids = feat_req['ids']
+ layers = feat_req.get('layers', None)
+ interventions = feat_req.get('interventions', None)
+ featuremaps = projects[project].get_featuremaps(
+ ids, layers, interventions)
+ return {
+ 'request': feat_req,
+ 'res': featuremaps
+ }
+
+@app.route('/client/<path:path>')
+def send_static(path):
+    """ serves all files from the client directory to ``/client/``
+
+    :param path: path from api call
+    """
+    return send_from_directory(args.client, path)
+
+@app.route('/data/<path:path>')
+def send_data(path):
+    """ serves all files from the data directory to ``/data/``
+
+    :param path: path from api call
+    """
+    print('Got the data route for', path)
+    return send_from_directory(args.data, path)
+
+
+@app.route('/')
+def redirect_home():
+ return redirect('/client/index.html', code=302)
+
+
+def load_projects(directory):
+ """
+ searches for CONFIG_FILE_NAME in all subdirectories of directory
+ and creates data handlers for all of them
+
+ :param directory: scan directory
+    :return: None
+ """
+ project_dirs = []
+ # Don't search more than 2 dirs deep.
+ search_depth = 2 + directory.count(os.path.sep)
+ for root, dirs, files in os.walk(directory):
+ if CONFIG_FILE_NAME in files:
+ project_dirs.append(root)
+ # Don't get subprojects under a project dir.
+ del dirs[:]
+ elif root.count(os.path.sep) >= search_depth:
+ del dirs[:]
+ for p_dir in project_dirs:
+ print('Loading %s' % os.path.join(p_dir, CONFIG_FILE_NAME))
+ with open(os.path.join(p_dir, CONFIG_FILE_NAME), 'r') as jf:
+ config = EasyDict(json.load(jf))
+ dh_id = os.path.split(p_dir)[1]
+ projects[dh_id] = DissectionProject(
+ config=config,
+ project_dir=p_dir,
+ path_url='data/' + os.path.relpath(p_dir, directory),
+ public_host=args.public_host)
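+
+# Example (hypothetical layout, not shipped with this repo): with --data=dissect,
+# a file dissect/churchoutdoor/dissect.json registers a project keyed
+# 'churchoutdoor', whose cached images are then served by send_data() under
+# /data/churchoutdoor/... .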
+
+app.add_api('server.yaml')
+
+# add CORS support
+CORS(app.app, headers='Content-Type')
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--nodebug", default=False)
+parser.add_argument("--address", default="127.0.0.1") # 0.0.0.0 for nonlocal use
+parser.add_argument("--port", default="5001")
+parser.add_argument("--public_host", default=None)
+parser.add_argument("--nocache", default=False)
+parser.add_argument("--data", type=str, default='dissect')
+parser.add_argument("--client", type=str, default='client_dist')
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ for d in [args.data, args.client]:
+ if not os.path.isdir(d):
+ print('No directory %s' % d)
+ sys.exit(1)
+ args.data = os.path.abspath(args.data)
+ args.client = os.path.abspath(args.client)
+ if args.public_host is None:
+ args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
+ app.run(port=int(args.port), debug=not args.nodebug, host=args.address,
+ use_reloader=False)
+else:
+ args, _ = parser.parse_known_args()
+ if args.public_host is None:
+ args.public_host = '%s:%d' % (socket.getfqdn(), int(args.port))
+ load_projects(args.data)
diff --git a/netdissect/server.yaml b/netdissect/server.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e67b9bbcb24397a21623009b4b6bf0e6d4c9193
--- /dev/null
+++ b/netdissect/server.yaml
@@ -0,0 +1,300 @@
+swagger: '2.0'
+info:
+ title: Ganter API
+ version: "0.1"
+consumes:
+ - application/json
+produces:
+ - application/json
+
+basePath: /api
+
+paths:
+ /all_projects:
+ get:
+ tags:
+ - all
+ summary: information about all projects and sources available
+ operationId: netdissect.server.get_all_projects
+ responses:
+ 200:
+ description: return list of projects
+ schema:
+ type: array
+ items:
+ type: object
+
+ /layers:
+ get:
+ operationId: netdissect.server.get_layers
+ tags:
+ - all
+ summary: returns information about all layers
+ parameters:
+ - $ref: '#/parameters/project'
+ responses:
+ 200:
+ description: Return requested data
+ schema:
+ type: object
+
+ /units:
+ get:
+ operationId: netdissect.server.get_units
+ tags:
+ - all
+ summary: returns unit information for one layer
+ parameters:
+
+ - $ref: '#/parameters/project'
+ - $ref: '#/parameters/layer'
+
+ responses:
+ 200:
+ description: Return requested data
+ schema:
+ type: object
+
+ /rankings:
+ get:
+ operationId: netdissect.server.get_rankings
+ tags:
+ - all
+ summary: returns ranking information for one layer
+ parameters:
+
+ - $ref: '#/parameters/project'
+ - $ref: '#/parameters/layer'
+
+ responses:
+ 200:
+ description: Return requested data
+ schema:
+ type: object
+
+ /levels:
+ get:
+ operationId: netdissect.server.get_levels
+ tags:
+ - all
+ summary: returns feature levels for one layer
+ parameters:
+
+ - $ref: '#/parameters/project'
+ - $ref: '#/parameters/layer'
+ - $ref: '#/parameters/quantiles'
+
+ responses:
+ 200:
+ description: Return requested data
+ schema:
+ type: object
+
+ /features:
+ post:
+ summary: calculates max feature values within a set of image locations
+ operationId: netdissect.server.post_features
+ tags:
+ - all
+ parameters:
+ - in: body
+ name: feat_req
+ description: RequestObject
+ schema:
+ $ref: "#/definitions/FeatureRequest"
+ responses:
+ 200:
+ description: returns feature vector for each layer
+
+ /featuremaps:
+ post:
+      summary: returns quantile-normalized feature maps for generated images
+ operationId: netdissect.server.post_featuremaps
+ tags:
+ - all
+ parameters:
+ - in: body
+ name: feat_req
+ description: RequestObject
+ schema:
+ $ref: "#/definitions/FeatureMapRequest"
+ responses:
+ 200:
+          description: returns feature maps for each layer
+
+ /channels:
+ get:
+ operationId: netdissect.server.get_channels
+ tags:
+ - all
+ summary: returns channel information
+ parameters:
+
+ - $ref: '#/parameters/project'
+ - $ref: '#/parameters/layer'
+
+ responses:
+ 200:
+ description: Return requested data
+ schema:
+ type: object
+
+ /generate:
+ post:
+ summary: generates images for given zs constrained by ablation
+ operationId: netdissect.server.post_generate
+ tags:
+ - all
+ parameters:
+ - in: body
+ name: gen_req
+ description: RequestObject
+ schema:
+ $ref: "#/definitions/GenerateRequest"
+ responses:
+ 200:
+          description: returns generated images
+
+
+parameters:
+ project:
+ name: project
+ description: project ID
+ in: query
+ required: true
+ type: string
+
+ layer:
+ name: layer
+ description: layer ID
+ in: query
+ type: string
+ default: "1"
+
+ quantiles:
+ name: quantiles
+ in: query
+ type: array
+ items:
+ type: number
+ format: float
+
+definitions:
+
+ GenerateRequest:
+ type: object
+ required:
+ - project
+ properties:
+ project:
+ type: string
+ zs:
+ type: array
+ items:
+ type: array
+ items:
+ type: number
+ format: float
+ ids:
+ type: array
+ items:
+ type: integer
+ return_urls:
+ type: integer
+ interventions:
+ type: array
+ items:
+          $ref: '#/definitions/Intervention'
+
+ FeatureRequest:
+ type: object
+ required:
+ - project
+ properties:
+ project:
+ type: string
+ example: 'churchoutdoor'
+ layers:
+ type: array
+ items:
+ type: string
+ example: [ 'layer5' ]
+ ids:
+ type: array
+ items:
+ type: integer
+ masks:
+ type: array
+ items:
+          $ref: '#/definitions/Mask'
+ interventions:
+ type: array
+ items:
+          $ref: '#/definitions/Intervention'
+
+ FeatureMapRequest:
+ type: object
+ required:
+ - project
+ properties:
+ project:
+ type: string
+ example: 'churchoutdoor'
+ layers:
+ type: array
+ items:
+ type: string
+ example: [ 'layer5' ]
+ ids:
+ type: array
+ items:
+ type: integer
+ interventions:
+ type: array
+ items:
+          $ref: '#/definitions/Intervention'
+
+ Intervention:
+ type: object
+ properties:
+ maskalpha:
+ $ref: '#/definitions/Mask'
+ maskvalue:
+ $ref: '#/definitions/Mask'
+ ablations:
+ type: array
+ items:
+          $ref: '#/definitions/Ablation'
+
+ Ablation:
+ type: object
+ properties:
+ unit:
+ type: integer
+ alpha:
+ type: number
+ format: float
+ value:
+ type: number
+ format: float
+ layer:
+ type: string
+
+ Mask:
+ type: object
+ description: 2d bitmap mask
+ properties:
+ shape:
+ type: array
+ items:
+ type: integer
+ example: [ 128, 128 ]
+ bitbounds:
+ type: array
+ items:
+ type: integer
+ example: [ 12, 42, 16, 46 ]
+ bitstring:
+ type: string
+ example: '0110111111110011'
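+# Reading the example above (illustrative): shape is the full mask size,
+# bitbounds [r0, c0, r1, c1] selects the sub-rectangle that the bitstring fills
+# in row-major order, so the 16 bits cover a 4x4 region of a 128x128 mask
+# (see mask_to_numpy() in netdissect/serverstate.py for the decoding).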
+
diff --git a/netdissect/serverstate.py b/netdissect/serverstate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ddc790c3dfc881f8aa4322d10d90e4e4fc09f0
--- /dev/null
+++ b/netdissect/serverstate.py
@@ -0,0 +1,526 @@
+import os, torch, numpy, base64, json, re, threading, random
+from torch.utils.data import TensorDataset, DataLoader
+from collections import defaultdict
+from netdissect.easydict import EasyDict
+from netdissect.modelconfig import create_instrumented_model
+from netdissect.runningstats import RunningQuantile
+from netdissect.dissection import safe_dir_name
+from netdissect.zdataset import z_sample_for_model
+from PIL import Image
+from io import BytesIO
+
+class DissectionProject:
+ '''
+    DissectionProject understands how to drive a GanTester within a
+ dissection project directory structure: it caches data in files,
+ creates image files, and translates data between plain python data
+ types and the pytorch-specific tensors required by GanTester.
+ '''
+ def __init__(self, config, project_dir, path_url, public_host):
+ print('config done', project_dir)
+ self.use_cuda = torch.cuda.is_available()
+ self.dissect = config
+ self.project_dir = project_dir
+ self.path_url = path_url
+ self.public_host = public_host
+ self.cachedir = os.path.join(self.project_dir, 'cache')
+ self.tester = GanTester(
+ config.settings, dissectdir=project_dir,
+ device=torch.device('cuda') if self.use_cuda
+ else torch.device('cpu'))
+ self.stdz = []
+
+ def get_zs(self, size):
+ if size <= len(self.stdz):
+ return self.stdz[:size].tolist()
+ z_tensor = self.tester.standard_z_sample(size)
+ numpy_z = z_tensor.cpu().numpy()
+ self.stdz = numpy_z
+ return self.stdz.tolist()
+
+ def get_z(self, id):
+ if id < len(self.stdz):
+ return self.stdz[id]
+ return self.get_zs((id + 1) * 2)[id]
+
+ def get_zs_for_ids(self, ids):
+ max_id = max(ids)
+ if max_id >= len(self.stdz):
+ self.get_z(max_id)
+ return self.stdz[ids]
+
+ def get_layers(self):
+ result = []
+ layer_shapes = self.tester.layer_shapes()
+ for layer in self.tester.layers:
+ shape = layer_shapes[layer]
+ result.append(dict(
+ layer=layer,
+ channels=shape[1],
+ shape=[shape[2], shape[3]]))
+ return result
+
+ def get_units(self, layer):
+ try:
+ dlayer = [dl for dl in self.dissect['layers']
+ if dl['layer'] == layer][0]
+ except:
+ return None
+
+ dunits = dlayer['units']
+ result = [dict(unit=unit_num,
+ img='/%s/%s/s-image/%d-top.jpg' %
+ (self.path_url, layer, unit_num),
+ label=unit['iou_label'])
+ for unit_num, unit in enumerate(dunits)]
+ return result
+
+ def get_rankings(self, layer):
+ try:
+ dlayer = [dl for dl in self.dissect['layers']
+ if dl['layer'] == layer][0]
+ except:
+ return None
+ result = [dict(name=ranking['name'],
+ metric=ranking.get('metric', None),
+ scores=ranking['score'])
+ for ranking in dlayer['rankings']]
+ return result
+
+ def get_levels(self, layer, quantiles):
+ levels = self.tester.levels(
+ layer, torch.from_numpy(numpy.array(quantiles)))
+ return levels.cpu().numpy().tolist()
+
+ def generate_images(self, zs, ids, interventions, return_urls=False):
+ if ids is not None:
+ assert zs is None
+ zs = self.get_zs_for_ids(ids)
+ if not interventions:
+ # Do file caching when ids are given (and no ablations).
+ imgdir = os.path.join(self.cachedir, 'img', 'id')
+ os.makedirs(imgdir, exist_ok=True)
+ exist = set(os.listdir(imgdir))
+ unfinished = [('%d.jpg' % id) not in exist for id in ids]
+ needed_z_tensor = torch.tensor(zs[unfinished]).float().to(
+ self.tester.device)
+ needed_ids = numpy.array(ids)[unfinished]
+ # Generate image files for just the needed images.
+ if len(needed_z_tensor):
+ imgs = self.tester.generate_images(needed_z_tensor
+ ).cpu().numpy()
+ for i, img in zip(needed_ids, imgs):
+ Image.fromarray(img.transpose(1, 2, 0)).save(
+ os.path.join(imgdir, '%d.jpg' % i), 'jpeg',
+ quality=99, optimize=True, progressive=True)
+ # Assemble a response.
+ imgurls = ['/%s/cache/img/id/%d.jpg'
+ % (self.path_url, i) for i in ids]
+ return [dict(id=i, d=d) for i, d in zip(ids, imgurls)]
+ # No file caching when ids are not given (or ablations are applied)
+ z_tensor = torch.tensor(zs).float().to(self.tester.device)
+ imgs = self.tester.generate_images(z_tensor,
+ intervention=decode_intervention_array(interventions,
+ self.tester.layer_shapes()),
+ ).cpu().numpy()
+ numpy_z = z_tensor.cpu().numpy()
+ if return_urls:
+ randdir = '%03d' % random.randrange(1000)
+ imgdir = os.path.join(self.cachedir, 'img', 'uniq', randdir)
+ os.makedirs(imgdir, exist_ok=True)
+ startind = random.randrange(100000)
+ imgurls = []
+ for i, img in enumerate(imgs):
+ filename = '%d.jpg' % (i + startind)
+ Image.fromarray(img.transpose(1, 2, 0)).save(
+ os.path.join(imgdir, filename), 'jpeg',
+ quality=99, optimize=True, progressive=True)
+ image_url_path = ('/%s/cache/img/uniq/%s/%s'
+ % (self.path_url, randdir, filename))
+ imgurls.append(image_url_path)
+ tweet_filename = 'tweet-%d.html' % (i + startind)
+ tweet_url_path = ('/%s/cache/img/uniq/%s/%s'
+ % (self.path_url, randdir, tweet_filename))
+ with open(os.path.join(imgdir, tweet_filename), 'w') as f:
+ f.write(twitter_card(image_url_path, tweet_url_path,
+ self.public_host))
+ return [dict(d=d) for d in imgurls]
+ imgurls = [img2base64(img.transpose(1, 2, 0)) for img in imgs]
+ return [dict(d=d) for d in imgurls]
+
+ def get_features(self, ids, masks, layers, interventions):
+ zs = self.get_zs_for_ids(ids)
+ z_tensor = torch.tensor(zs).float().to(self.tester.device)
+ t_masks = torch.stack(
+ [torch.from_numpy(mask_to_numpy(mask)) for mask in masks]
+ )[:,None,:,:].to(self.tester.device)
+ t_features = self.tester.feature_stats(z_tensor, t_masks,
+ decode_intervention_array(interventions,
+ self.tester.layer_shapes()), layers)
+ # Convert torch arrays to plain python lists before returning.
+ return { layer: { key: value.cpu().numpy().tolist()
+ for key, value in feature.items() }
+ for layer, feature in t_features.items() }
+
+ def get_featuremaps(self, ids, layers, interventions):
+ zs = self.get_zs_for_ids(ids)
+ z_tensor = torch.tensor(zs).float().to(self.tester.device)
+        # Quantile-normalized features are returned.
+ q_features = self.tester.feature_maps(z_tensor,
+ decode_intervention_array(interventions,
+ self.tester.layer_shapes()), layers)
+ # Scale them 0-255 and return them.
+ # TODO: turn them into pngs for returning.
+ return { layer: [
+ value.clamp(0, 1).mul(255).byte().cpu().numpy().tolist()
+ for value in valuelist ]
+ for layer, valuelist in q_features.items()
+ if (not layers) or (layer in layers) }
+
+ def get_recipes(self):
+ recipedir = os.path.join(self.project_dir, 'recipe')
+ if not os.path.isdir(recipedir):
+ return []
+ result = []
+ for filename in os.listdir(recipedir):
+ with open(os.path.join(recipedir, filename)) as f:
+ result.append(json.load(f))
+ return result
+
+
+
+
+class GanTester:
+ '''
+ GanTester holds on to a specific model to test.
+
+    (1) loads and instantiates the GAN;
+    (2) instruments it at every layer so that units can be ablated; and
+    (3) precomputes z dimensionality and output image dimensions.
+ '''
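+    # A sketch of typical use (settings come from the project's dissect.json;
+    # the names follow DissectionProject above):
+    #   tester = GanTester(config.settings, dissectdir=project_dir)
+    #   zs = tester.standard_z_sample(8)
+    #   imgs = tester.generate_images(zs)  # uint8 tensor, shape (8, 3, H, W)
+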
+ def __init__(self, args, dissectdir=None, device=None):
+ self.cachedir = os.path.join(dissectdir, 'cache')
+ self.device = device if device is not None else torch.device('cpu')
+ self.dissectdir = dissectdir
+ self.modellock = threading.Lock()
+
+ # Load the generator from the pth file.
+ args_copy = EasyDict(args)
+ args_copy.edit = True
+ model = create_instrumented_model(args_copy)
+ model.eval()
+ self.model = model
+
+ # Get the set of layers of interest.
+ # Default: all shallow children except last.
+ self.layers = sorted(model.retained_features().keys())
+
+ # Move it to CUDA if wanted.
+ model.to(device)
+
+ self.quantiles = {
+ layer: load_quantile_if_present(os.path.join(self.dissectdir,
+ safe_dir_name(layer)), 'quantiles.npz',
+ device=torch.device('cpu'))
+ for layer in self.layers }
+
+ def layer_shapes(self):
+ return self.model.feature_shape
+
+ def standard_z_sample(self, size=100, seed=1, device=None):
+ '''
+ Generate a standard set of random Z as a (size, z_dimension) tensor.
+ With the same random seed, it always returns the same z (e.g.,
+ the first one is always the same regardless of the size.)
+ '''
+ result = z_sample_for_model(self.model, size)
+ if device is not None:
+ result = result.to(device)
+ return result
+
+ def reset_intervention(self):
+ self.model.remove_edits()
+
+ def apply_intervention(self, intervention):
+ '''
+        Applies an intervention of the form {layer: (alpha, value)}, where
+        alpha and value are featuremap tensors (see decode_intervention_array).
+ '''
+ self.reset_intervention()
+ if not intervention:
+ return
+ for layer, (a, v) in intervention.items():
+ self.model.edit_layer(layer, ablation=a, replacement=v)
+
+ def generate_images(self, z_batch, intervention=None):
+ '''
+        Generates a batch of uint8 images from z_batch, applying the given
+        intervention, if any.
+ '''
+ with torch.no_grad(), self.modellock:
+ batch_size = 10
+ self.apply_intervention(intervention)
+ test_loader = DataLoader(TensorDataset(z_batch[:,:,None,None]),
+ batch_size=batch_size,
+ pin_memory=('cuda' == self.device.type
+ and z_batch.device.type == 'cpu'))
+ result_img = torch.zeros(
+ *((len(z_batch), 3) + self.model.output_shape[2:]),
+ dtype=torch.uint8, device=self.device)
+ for batch_num, [batch_z,] in enumerate(test_loader):
+ batch_z = batch_z.to(self.device)
+ out = self.model(batch_z)
+ result_img[batch_num*batch_size:
+ batch_num*batch_size+len(batch_z)] = (
+ (((out + 1) / 2) * 255).clamp(0, 255).byte())
+ return result_img
+
+ def get_layers(self):
+ return self.layers
+
+ def feature_stats(self, z_batch,
+ masks=None, intervention=None, layers=None):
+ feature_stat = defaultdict(dict)
+ with torch.no_grad(), self.modellock:
+ batch_size = 10
+ self.apply_intervention(intervention)
+ if masks is None:
+ masks = torch.ones(z_batch.size(0), 1, 1, 1,
+ device=z_batch.device, dtype=z_batch.dtype)
+ else:
+ assert masks.shape[0] == z_batch.shape[0]
+ assert masks.shape[1] == 1
+ test_loader = DataLoader(
+ TensorDataset(z_batch[:,:,None,None], masks),
+ batch_size=batch_size,
+ pin_memory=('cuda' == self.device.type
+ and z_batch.device.type == 'cpu'))
+ processed = 0
+ for batch_num, [batch_z, batch_m] in enumerate(test_loader):
+ batch_z, batch_m = [
+ d.to(self.device) for d in [batch_z, batch_m]]
+ # Run model but disregard output
+ self.model(batch_z)
+ processing = batch_z.shape[0]
+ for layer, feature in self.model.retained_features().items():
+ if layers is not None:
+ if layer not in layers:
+ continue
+ # Compute max features touching mask
+ resized_max = torch.nn.functional.adaptive_max_pool2d(
+ batch_m,
+ (feature.shape[2], feature.shape[3]))
+ max_feature = (feature * resized_max).view(
+ feature.shape[0], feature.shape[1], -1
+ ).max(2)[0].max(0)[0]
+ if 'max' not in feature_stat[layer]:
+ feature_stat[layer]['max'] = max_feature
+ else:
+ torch.max(feature_stat[layer]['max'], max_feature,
+ out=feature_stat[layer]['max'])
+ # Compute mean features weighted by overlap with mask
+ resized_mean = torch.nn.functional.adaptive_avg_pool2d(
+ batch_m,
+ (feature.shape[2], feature.shape[3]))
+ mean_feature = (feature * resized_mean).view(
+ feature.shape[0], feature.shape[1], -1
+ ).sum(2).sum(0) / (resized_mean.sum() + 1e-15)
+ if 'mean' not in feature_stat[layer]:
+ feature_stat[layer]['mean'] = mean_feature
+ else:
+ feature_stat[layer]['mean'] = (
+                            processed * feature_stat[layer]['mean']
+ + processing * mean_feature) / (
+ processed + processing)
+ processed += processing
+ # After summaries are done, also compute quantile stats
+ for layer, stats in feature_stat.items():
+ if self.quantiles.get(layer, None) is not None:
+ for statname in ['max', 'mean']:
+ stats['%s_quantile' % statname] = (
+ self.quantiles[layer].normalize(stats[statname]))
+ return feature_stat
+
+ def levels(self, layer, quantiles):
+ return self.quantiles[layer].quantiles(quantiles)
+
+ def feature_maps(self, z_batch, intervention=None, layers=None,
+ quantiles=True):
+ feature_map = defaultdict(list)
+ with torch.no_grad(), self.modellock:
+ batch_size = 10
+ self.apply_intervention(intervention)
+ test_loader = DataLoader(
+ TensorDataset(z_batch[:,:,None,None]),
+ batch_size=batch_size,
+ pin_memory=('cuda' == self.device.type
+ and z_batch.device.type == 'cpu'))
+ processed = 0
+ for batch_num, [batch_z] in enumerate(test_loader):
+ batch_z = batch_z.to(self.device)
+ # Run model but disregard output
+ self.model(batch_z)
+ processing = batch_z.shape[0]
+ for layer, feature in self.model.retained_features().items():
+ for single_featuremap in feature:
+ if quantiles:
+ feature_map[layer].append(self.quantiles[layer]
+ .normalize(single_featuremap))
+ else:
+ feature_map[layer].append(single_featuremap)
+ return feature_map
+
+def load_quantile_if_present(outdir, filename, device):
+ filepath = os.path.join(outdir, filename)
+ if os.path.isfile(filepath):
+ data = numpy.load(filepath)
+ result = RunningQuantile(state=data)
+ result.to_(device)
+ return result
+ return None
+
+# (disabled: test_main() is not defined in this module)
+# if __name__ == '__main__':
+#     test_main()
+
+def mask_to_numpy(mask_record):
+ # Detect a png image mask.
+ bitstring = mask_record['bitstring']
+ bitnumpy = None
+ default_shape = (256, 256)
+ if 'image/png;base64,' in bitstring:
+ bitnumpy = base642img(bitstring)
+ default_shape = bitnumpy.shape[:2]
+ # Set up results
+ shape = mask_record.get('shape', None)
+ if not shape: # None or empty []
+ shape = default_shape
+ result = numpy.zeros(shape=shape, dtype=numpy.float32)
+ bitbounds = mask_record.get('bitbounds', None)
+ if not bitbounds: # None or empty []
+ bitbounds = ([0] * len(result.shape)) + list(result.shape)
+ start = bitbounds[:len(result.shape)]
+ end = bitbounds[len(result.shape):]
+ if bitnumpy is not None:
+ if bitnumpy.shape[2] == 4:
+ # Mask is any nontransparent bits in the alpha channel if present
+ result[start[0]:end[0], start[1]:end[1]] = (bitnumpy[:,:,3] > 0)
+ else:
+ # Or any nonwhite pixels in the red channel if no alpha.
+ result[start[0]:end[0], start[1]:end[1]] = (bitnumpy[:,:,0] < 255)
+ return result
+ else:
+ # Or bitstring can be just ones and zeros.
+ indexes = start.copy()
+ bitindex = 0
+ while True:
+ result[tuple(indexes)] = (bitstring[bitindex] != '0')
+ for ii in range(len(indexes) - 1, -1, -1):
+ if indexes[ii] < end[ii] - 1:
+ break
+ indexes[ii] = start[ii]
+ else:
+ assert (bitindex + 1) == len(bitstring)
+ return result
+ indexes[ii] += 1
+ bitindex += 1
+
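+# Worked example (values invented for illustration): a mask record
+#   {'shape': [4, 4], 'bitbounds': [1, 1, 3, 3], 'bitstring': '1011'}
+# yields a 4x4 float32 array whose [1:3, 1:3] block is [[1, 0], [1, 1]]:
+# the four bits fill the bounded sub-rectangle in row-major order.
+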
+def decode_intervention_array(interventions, layer_shapes):
+ result = {}
+ for channels in [decode_intervention(intervention, layer_shapes)
+ for intervention in (interventions or [])]:
+ for layer, channel in channels.items():
+ if layer not in result:
+ result[layer] = channel
+ continue
+ accum = result[layer]
+ newalpha = 1 - (1 - channel[:1]) * (1 - accum[:1])
+ newvalue = (accum[1:] * accum[:1] * (1 - channel[:1]) +
+ channel[1:] * channel[:1]) / (newalpha + 1e-40)
+ accum[:1] = newalpha
+ accum[1:] = newvalue
+ return result
+
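+# The loop above composites repeated interventions on the same layer with
+# "over" alpha blending. Illustrative numbers: two ablations with alpha 0.5
+# each give newalpha = 1 - (1 - 0.5) * (1 - 0.5) = 0.75, with the earlier
+# value's weight attenuated by (1 - 0.5) relative to the later one.
+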
+def decode_intervention(intervention, layer_shapes):
+ # Every plane of an intervention is a solid choice of activation
+ # over a set of channels, with a mask applied to alpha-blended channels
+ # (when the mask resolution is different from the feature map, it can
+    # be either max-pooled or average-pooled to the proper resolution).
+ # This can be reduced to a single alpha-blended featuremap.
+ if intervention is None:
+ return None
+ mask = intervention.get('mask', None)
+ if mask:
+ mask = torch.from_numpy(mask_to_numpy(mask))
+ maskpooling = intervention.get('maskpooling', 'max')
+ channels = {} # layer -> ([alpha, val], c)
+ for arec in intervention.get('ablations', []):
+ unit = arec['unit']
+ layer = arec['layer']
+ alpha = arec.get('alpha', 1.0)
+ if alpha is None:
+ alpha = 1.0
+ value = arec.get('value', 0.0)
+ if value is None:
+ value = 0.0
+ if alpha != 0.0 or value != 0.0:
+ if layer not in channels:
+ channels[layer] = torch.zeros(2, *layer_shapes[layer][1:])
+ channels[layer][0, unit] = alpha
+ channels[layer][1, unit] = value
+ if mask is not None:
+ for layer in channels:
+ layer_shape = layer_shapes[layer][2:]
+ if maskpooling == 'mean':
+ layer_mask = torch.nn.functional.adaptive_avg_pool2d(
+ mask[None,None,...], layer_shape)[0]
+ else:
+ layer_mask = torch.nn.functional.adaptive_max_pool2d(
+ mask[None,None,...], layer_shape)[0]
+ channels[layer][0] *= layer_mask
+ return channels
+
+def img2base64(imgarray, for_html=True, image_format='jpeg'):
+ '''
+ Converts a numpy array to a jpeg base64 url
+ '''
+ input_image_buff = BytesIO()
+ Image.fromarray(imgarray).save(input_image_buff, image_format,
+ quality=99, optimize=True, progressive=True)
+ res = base64.b64encode(input_image_buff.getvalue()).decode('ascii')
+ if for_html:
+ return 'data:image/' + image_format + ';base64,' + res
+ else:
+ return res
+
+def base642img(stringdata):
+    stringdata = re.sub(r'^(?:data:)?image/\w+;base64,', '', stringdata)
+ im = Image.open(BytesIO(base64.b64decode(stringdata)))
+ return numpy.array(im)
+
+def twitter_card(image_path, tweet_path, public_host):
+ return '''\
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+Painting with GANs from MIT-IBM Watson AI Lab
+
+This demo lets you modify a selection of meaningful GAN units for a generated image by simply painting.