taquynhnga committed on
Commit
18f2f54
1 Parent(s): 5d3c300

lfs track pickle & json & csv

.gitattributes CHANGED
@@ -1,10 +1,3 @@
1
  *.pkl filter=lfs diff=lfs merge=lfs -text
2
  *.json filter=lfs diff=lfs merge=lfs -text
3
- *.md filter=lfs diff=lfs merge=lfs -text
4
- *.yml filter=lfs diff=lfs merge=lfs -text
5
- *.gitignore filter=lfs diff=lfs merge=lfs -text
6
- *.py filter=lfs diff=lfs merge=lfs -text
7
- *.csv filter=lfs diff=lfs merge=lfs -text
8
- *.dot filter=lfs diff=lfs merge=lfs -text
9
- *.html filter=lfs diff=lfs merge=lfs -text
10
- *.txt filter=lfs diff=lfs merge=lfs -text
 
1
  *.pkl filter=lfs diff=lfs merge=lfs -text
2
  *.json filter=lfs diff=lfs merge=lfs -text
3
+ .csv filter=lfs diff=lfs merge=lfs -text
.github/workflows/sync_to_huggingface_hub.yml CHANGED
@@ -1,3 +1,19 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e53686e913140e2b95d86fde5ca8411ce2ca06284ddddfb34d58d59c931b6ec6
3
- size 488
1
+ name: Sync to Hugging Face hub
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-hub:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ with:
15
+ fetch-depth: 0
16
+ - name: Push to hub
17
+ env:
18
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
19
+ run: git push --force https://taquynhnga:[email protected]/spaces/taquynhnga/CNNs-interpretation-visualization main
.gitignore CHANGED
@@ -1,3 +1,183 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:700edf4b441c3e125db4aa3145f426061b430a4e058e73f30e79de4759d0a71a
3
- size 2670
1
+ # VSCode
2
+ .vscode/*
3
+ !.vscode/settings.json
4
+ !.vscode/tasks.json
5
+ !.vscode/launch.json
6
+ !.vscode/extensions.json
7
+ *.code-workspace
8
+ # Local History for Visual Studio Code
9
+ .history/
10
+
11
+ # Common credential files
12
+ **/credentials.json
13
+ **/client_secrets.json
14
+ **/client_secret.json
15
+ *creds*
16
+ *.dat
17
+ *password*
18
+ *.httr-oauth*
19
+
20
+ # Private Node Modules
21
+ node_modules/
22
+ creds.js
23
+
24
+ # Private Files
25
+ # *.json
26
+ # *.csv
27
+ # *.csv.gz
28
+ # *.tsv
29
+ # *.tsv.gz
30
+ # *.xlsx
31
+ git-large-file
32
+ deta_drive.py
33
+ secret_keys.py
34
+
35
+ # Large files
36
+ # data/preprocessed_image_net/
37
+ # data/activation/*.pkl
38
+ # data/activation/*.json
39
+ # data/layer_infos/*.pkl
40
+ # data/layer_infos/*.json
41
+
42
+ # Mac/OSX
43
+ .DS_Store
44
+
45
+
46
+ # Byte-compiled / optimized / DLL files
47
+ __pycache__/
48
+ *.py[cod]
49
+ *$py.class
50
+
51
+ # C extensions
52
+ *.so
53
+
54
+ # Distribution / packaging
55
+ .Python
56
+ build/
57
+ develop-eggs/
58
+ dist/
59
+ downloads/
60
+ eggs/
61
+ .eggs/
62
+ lib/
63
+ lib64/
64
+ parts/
65
+ sdist/
66
+ var/
67
+ wheels/
68
+ share/python-wheels/
69
+ *.egg-info/
70
+ .installed.cfg
71
+ *.egg
72
+ MANIFEST
73
+
74
+ # PyInstaller
75
+ # Usually these files are written by a python script from a template
76
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
77
+ *.manifest
78
+ *.spec
79
+
80
+ # Installer logs
81
+ pip-log.txt
82
+ pip-delete-this-directory.txt
83
+
84
+ # Unit test / coverage reports
85
+ htmlcov/
86
+ .tox/
87
+ .nox/
88
+ .coverage
89
+ .coverage.*
90
+ .cache
91
+ nosetests.xml
92
+ coverage.xml
93
+ *.cover
94
+ *.py,cover
95
+ .hypothesis/
96
+ .pytest_cache/
97
+ cover/
98
+
99
+ # Translations
100
+ *.mo
101
+ *.pot
102
+
103
+ # Django stuff:
104
+ *.log
105
+ local_settings.py
106
+ db.sqlite3
107
+ db.sqlite3-journal
108
+
109
+ # Flask stuff:
110
+ instance/
111
+ .webassets-cache
112
+
113
+ # Scrapy stuff:
114
+ .scrapy
115
+
116
+ # Sphinx documentation
117
+ docs/_build/
118
+
119
+ # PyBuilder
120
+ .pybuilder/
121
+ target/
122
+
123
+ # Jupyter Notebook
124
+ .ipynb_checkpoints
125
+
126
+ # IPython
127
+ profile_default/
128
+ ipython_config.py
129
+
130
+ # pyenv
131
+ # For a library or package, you might want to ignore these files since the code is
132
+ # intended to run in multiple environments; otherwise, check them in:
133
+ # .python-version
134
+
135
+ # pipenv
136
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
137
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
138
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
139
+ # install all needed dependencies.
140
+ #Pipfile.lock
141
+
142
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
143
+ __pypackages__/
144
+
145
+ # Celery stuff
146
+ celerybeat-schedule
147
+ celerybeat.pid
148
+
149
+ # SageMath parsed files
150
+ *.sage.py
151
+
152
+ # Environments
153
+ .env
154
+ .venv
155
+ env/
156
+ venv/
157
+ ENV/
158
+ env.bak/
159
+ venv.bak/
160
+
161
+ # Spyder project settings
162
+ .spyderproject
163
+ .spyproject
164
+
165
+ # Rope project settings
166
+ .ropeproject
167
+
168
+ # mkdocs documentation
169
+ /site
170
+
171
+ # mypy
172
+ .mypy_cache/
173
+ .dmypy.json
174
+ dmypy.json
175
+
176
+ # Pyre type checker
177
+ .pyre/
178
+
179
+ # pytype static type analyzer
180
+ .pytype/
181
+
182
+ # Cython debug symbols
183
+ cython_debug/
Home.py CHANGED
@@ -1,3 +1,13 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d5202c200b08978fee68e52e2eaaf26868071d71a681365116e9726d994c3f12
3
- size 256
1
+ import os
2
+ # os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
3
+
4
+ import streamlit as st
5
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
6
+ import torch
7
+
8
+ st.set_page_config(layout='wide')
9
+ st.title('About')
10
+
11
+ st.write('Loaded 3 models')
12
+
13
+
README.md CHANGED
@@ -1,3 +1,18 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:64ec7eb74a38e942caa6dc3b16b6e2cdd8fdba340fc0f260b1ffc2627825d7e6
3
- size 533
1
+ ---
2
+ title: CNNs Interpretation Visualization
3
+ emoji: 🕯
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: streamlit
7
+ sdk_version: 1.10.0
8
+ app_file: Home.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Visualizing Interpretations of CNN models: ConvNeXt, ResNet and MobileNet
13
+
14
+ To be renamed: CNNs-interpretation-visualization
15
+
16
+ This app was built with Streamlit. To run it, execute `streamlit run Home.py` in the terminal.
17
+
18
+ This repo is missing one folder, `data/preprocessed_image_net`, which contains 50,000 preprocessed ImageNet validation images saved in 5 pickle files.
backend/load_file.py CHANGED
@@ -1,3 +1,37 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b31a832332b15b85d097b62cdcf616e4f8c242886908c1386dce7c73ff5b1e4a
3
- size 1202
1
+ import json
2
+ import pickle
3
+ import numpy as np
4
+ from collections import OrderedDict
5
+
6
+ def load_pickle(filename):
7
+ with open(filename, 'rb') as file:
8
+ data = pickle.load(file)
9
+ return data
10
+
11
+ def save_pickle_to_json(filename):
12
+ ordered_dict = load_pickle(filename)
13
+ json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
14
+ with open(filename.replace('.pkl', '.json'), 'w') as f:
15
+ f.write(json_obj)
16
+
17
+ def load_json(filename):
18
+ with open(filename, 'r') as read_file:
19
+ loaded_dict = json.loads(read_file.read())
20
+ loaded_dict = OrderedDict(loaded_dict)
21
+ for k, v in loaded_dict.items():
22
+ loaded_dict[k] = np.asarray(v)
23
+ return loaded_dict
24
+
25
+ class NumpyEncoder(json.JSONEncoder):
26
+ def default(self, obj):
27
+ if isinstance(obj, np.ndarray):
28
+ return obj.tolist()
29
+ return json.JSONEncoder.default(self, obj)
30
+
31
+ # save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
32
+ # save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
33
+ # save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
34
+
35
+ file = load_json('data/layer_infos/convnext_layer_infos.json')
36
+ print(type(file))
37
+ print(type(file['embeddings.patch_embeddings']))
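The helpers in backend/load_file.py round-trip a pickled dictionary of NumPy arrays through JSON. A minimal usage sketch, assuming a hypothetical data/layer_infos/example_layer_infos.pkl that holds such a dictionary (the path is illustrative, not part of this commit):

from backend.load_file import save_pickle_to_json, load_json

# Hypothetical pickle produced elsewhere; it must contain a dict of NumPy arrays.
pkl_path = 'data/layer_infos/example_layer_infos.pkl'

# Writes example_layer_infos.json next to the pickle; NumpyEncoder
# serializes each ndarray as a nested list.
save_pickle_to_json(pkl_path)

# load_json restores an OrderedDict and converts every value back to an ndarray.
layer_infos = load_json(pkl_path.replace('.pkl', '.json'))
for name, arr in layer_infos.items():
    print(name, arr.shape)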
backend/maximally_activating_patches.py CHANGED
@@ -1,3 +1,43 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:92ef13b9479b334f8c60677cb33116d944f2d841d02a9f638c54e8d5d43dd3e5
3
- size 1556
1
+ import pickle
2
+ import streamlit as st
3
+
4
+ from backend.load_file import load_json
5
+
6
+
7
+ @st.cache(allow_output_mutation=True)
8
+ def load_activation(filename):
9
+ activation = load_json(filename)
10
+ return activation
11
+
12
+ @st.cache(allow_output_mutation=True)
13
+ def load_dataset(data_index):
14
+ with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
15
+ dataset = pickle.load(file)
16
+ return dataset
17
+
18
+ def load_layer_infos(filename):
19
+ layer_infos = load_json(filename)
20
+ return layer_infos
21
+
22
+ def get_receptive_field_coordinates(layer_infos, layer_name, idx_x, idx_y):
23
+ """
24
+ layer_name: as in layer_infos keys (eg: 'encoder.stages[0].layers[0]')
25
+ idx_x: integer coordinate along the width axis of the feature map; must be < n
26
+ idx_y: integer coordinate along the height axis of the feature map; must be < n
27
+ """
28
+ layer_name = layer_name.replace('.dwconv', '').replace('.layernorm', '')
29
+ layer_name = layer_name.replace('.pwconv1', '').replace('.pwconv2', '').replace('.drop_path', '')
30
+ n = layer_infos[layer_name]['n']
31
+ j = layer_infos[layer_name]['j']
32
+ r = layer_infos[layer_name]['r']
33
+ start = layer_infos[layer_name]['start']
34
+ assert idx_x < n, f'n={n}'
35
+ assert idx_y < n, f'n={n}'
36
+
37
+ # image tensor (N, H, W, C) or (N, C, H, W) => image_patch=image[y1:y2, x1:x2]
38
+ center = (start + idx_x*j, start + idx_y*j)
39
+ x1, x2 = (max(center[0]-r/2, 0), max(center[0]+r/2, 0))
40
+ y1, y2 = (max(center[1]-r/2, 0), max(center[1]+r/2, 0))
41
+ x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
42
+
43
+ return x1, x2, y1, y2
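get_receptive_field_coordinates maps a feature-map cell (idx_x, idx_y) back to an input-image box using the layer's receptive-field parameters: n (feature-map size), j (jump between centers in input pixels), r (receptive-field size) and start (center of the first cell). A small standalone sketch of the same arithmetic, with made-up parameter values rather than ones from the repo's layer_infos JSON:

# Hypothetical parameters for one layer; real values come from data/layer_infos/*.json.
info = {'n': 56, 'j': 4, 'r': 7, 'start': 0.5}

def receptive_field_box(info, idx_x, idx_y):
    # Center of the receptive field in input-pixel coordinates.
    cx = info['start'] + idx_x * info['j']
    cy = info['start'] + idx_y * info['j']
    half = info['r'] / 2
    # Clamp at 0 on the low side, mirroring the function above.
    x1, x2 = max(cx - half, 0), max(cx + half, 0)
    y1, y2 = max(cy - half, 0), max(cy + half, 0)
    return int(x1), int(x2), int(y1), int(y2)

# Feature-map cell (10, 3) maps to a 7x7 input patch: (x1, x2, y1, y2) = (37, 44, 9, 16).
print(receptive_field_box(info, 10, 3))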
backend/smooth_grad.py CHANGED
@@ -1,3 +1,233 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:34a1734def98b859f4162850f92032e255590f9f2bd28d853273b2e7c8022965
3
- size 7493
1
+ import sys
2
+ import os
3
+ import PIL
4
+ from PIL import Image
5
+ import numpy as np
6
+ from matplotlib import pylab as P
7
+ import cv2
8
+
9
+ import torch
10
+ from torch.utils.data import TensorDataset
11
+ from torchvision import transforms
12
+
13
+ dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
14
+ sys.path.append(dirpath_to_modules)
15
+
16
+ from torchvex.base import ExplanationMethod
17
+ from torchvex.utils.normalization import clamp_quantile
18
+
19
+ def ShowImage(im, title='', ax=None):
20
+ image = np.array(im)
21
+ return image
22
+
23
+ def ShowGrayscaleImage(im, title='', ax=None):
24
+ if ax is None:
25
+ P.figure()
26
+ P.axis('off')
27
+ P.imshow(im, cmap=P.cm.gray, vmin=0, vmax=1)
28
+ P.title(title)
29
+ return P
30
+
31
+ def ShowHeatMap(im, title='', ax=None):
32
+ im = im - im.min()
33
+ im = im / im.max()
34
+ im = im.clip(0,1)
35
+ im = np.uint8(im * 255)
36
+
37
+ im = cv2.resize(im, (224,224))
38
+ image = cv2.resize(im, (224, 224))
39
+
40
+ # Apply JET colormap
41
+ color_heatmap = cv2.applyColorMap(image, cv2.COLORMAP_HOT)
42
+ # P.imshow(im, cmap='inferno')
43
+ # P.title(title)
44
+ return color_heatmap
45
+
46
+ def ShowMaskedImage(saliency_map, image, title='', ax=None):
47
+ """
48
+ Save saliency map on image.
49
+
50
+ Args:
51
+ image: Tensor of size (H,W,3)
52
+ saliency_map: Tensor of size (H,W,1)
53
+ """
54
+
55
+ # if ax is None:
56
+ # P.figure()
57
+ # P.axis('off')
58
+
59
+ saliency_map = saliency_map - saliency_map.min()
60
+ saliency_map = saliency_map / saliency_map.max()
61
+ saliency_map = saliency_map.clip(0,1)
62
+ saliency_map = np.uint8(saliency_map * 255)
63
+
64
+ saliency_map = cv2.resize(saliency_map, (224,224))
65
+ image = cv2.resize(image, (224, 224))
66
+
67
+ # Apply JET colormap
68
+ color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_HOT)
69
+
70
+ # Blend image with heatmap
71
+ img_with_heatmap = cv2.addWeighted(image, 0.4, color_heatmap, 0.6, 0)
72
+
73
+ # P.imshow(img_with_heatmap)
74
+ # P.title(title)
75
+ return img_with_heatmap
76
+
77
+ def LoadImage(file_path):
78
+ im = PIL.Image.open(file_path)
79
+ im = im.resize((224, 224))
80
+ im = np.asarray(im)
81
+ return im
82
+
83
+
84
+ def visualize_image_grayscale(image_3d, percentile=99):
85
+ r"""Returns a 3D tensor as a grayscale 2D tensor.
86
+ This method sums a 3D tensor across the absolute value of axis=2, and then
87
+ clips values at a given percentile.
88
+ """
89
+ image_2d = np.sum(np.abs(image_3d), axis=2)
90
+
91
+ vmax = np.percentile(image_2d, percentile)
92
+ vmin = np.min(image_2d)
93
+
94
+ return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1)
95
+
96
+ def visualize_image_diverging(image_3d, percentile=99):
97
+ r"""Returns a 3D tensor as a 2D tensor with positive and negative values.
98
+ """
99
+ image_2d = np.sum(image_3d, axis=2)
100
+
101
+ span = abs(np.percentile(image_2d, percentile))
102
+ vmin = -span
103
+ vmax = span
104
+
105
+ return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1)
106
+
107
+
108
+ class SimpleGradient(ExplanationMethod):
109
+ def __init__(self, model, create_graph=False,
110
+ preprocess=None, postprocess=None):
111
+ super().__init__(model, preprocess, postprocess)
112
+ self.create_graph = create_graph
113
+
114
+ def predict(self, x):
115
+ return self.model(x)
116
+
117
+ @torch.enable_grad()
118
+ def process(self, inputs, target):
119
+ self.model.zero_grad()
120
+ inputs.requires_grad_(True)
121
+
122
+ out = self.model(inputs)
123
+ out = out if type(out) == torch.Tensor else out.logits
124
+
125
+ num_classes = out.size(-1)
126
+ onehot = torch.zeros(inputs.size(0), num_classes, *target.shape[1:])
127
+ onehot = onehot.to(dtype=inputs.dtype, device=inputs.device)
128
+ onehot.scatter_(1, target.unsqueeze(1), 1)
129
+
130
+ grad, = torch.autograd.grad(
131
+ (out*onehot).sum(), inputs, create_graph=self.create_graph
132
+ )
133
+
134
+ return grad
135
+
136
+
137
+ class SmoothGradient(ExplanationMethod):
138
+ def __init__(self, model, stdev_spread=0.15, num_samples=25,
139
+ magnitude=True, batch_size=-1,
140
+ create_graph=False, preprocess=None, postprocess=None):
141
+ super().__init__(model, preprocess, postprocess)
142
+ self.stdev_spread = stdev_spread
143
+ self.nsample = num_samples
144
+ self.create_graph = create_graph
145
+ self.magnitude = magnitude
146
+ self.batch_size = batch_size
147
+ if self.batch_size == -1:
148
+ self.batch_size = self.nsample
149
+
150
+ self._simgrad = SimpleGradient(model, create_graph)
151
+
152
+ def process(self, inputs, target):
153
+ self.model.zero_grad()
154
+
155
+ maxima = inputs.flatten(1).max(-1)[0]
156
+ minima = inputs.flatten(1).min(-1)[0]
157
+
158
+ stdev = self.stdev_spread * (maxima - minima).cpu()
159
+ stdev = stdev.view(inputs.size(0), 1, 1, 1).expand_as(inputs)
160
+ stdev = stdev.unsqueeze(0).expand(self.nsample, *[-1]*4)
161
+ noise = torch.normal(0, stdev)
162
+
163
+ target_expanded = target.unsqueeze(0).cpu()
164
+ target_expanded = target_expanded.expand(noise.size(0), -1)
165
+
166
+ noiseloader = torch.utils.data.DataLoader(
167
+ TensorDataset(noise, target_expanded), batch_size=self.batch_size
168
+ )
169
+
170
+ total_gradients = torch.zeros_like(inputs)
171
+ for noise, t_exp in noiseloader:
172
+ inputs_w_noise = inputs.unsqueeze(0) + noise.to(inputs.device)
173
+ inputs_w_noise = inputs_w_noise.view(-1, *inputs.shape[1:])
174
+ gradients = self._simgrad(inputs_w_noise, t_exp.view(-1))
175
+ gradients = gradients.view(self.batch_size, *inputs.shape)
176
+ if self.magnitude:
177
+ gradients = gradients.pow(2)
178
+ total_gradients = total_gradients + gradients.sum(0)
179
+
180
+ smoothed_gradient = total_gradients / self.nsample
181
+ return smoothed_gradient
182
+
183
+
184
+ def feed_forward(model_name, image, model=None, feature_extractor=None):
185
+ if model_name in ['ConvNeXt', 'ResNet']:
186
+ inputs = feature_extractor(image, return_tensors="pt")
187
+ logits = model(**inputs).logits
188
+ prediction_class = logits.argmax(-1).item()
189
+ else:
190
+ transform_images = transforms.Compose([
191
+ transforms.Resize(224),
192
+ transforms.ToTensor(),
193
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
194
+ input_tensor = transform_images(image)
195
+ inputs = input_tensor.unsqueeze(0)
196
+
197
+ output = model(inputs)
198
+ prediction_class = output.argmax(-1).item()
199
+ #prediction_label = model.config.id2label[prediction_class]
200
+ return inputs, prediction_class
201
+
202
+ def clip_gradient(gradient):
203
+ gradient = gradient.abs().sum(1, keepdim=True)
204
+ return clamp_quantile(gradient, q=0.99)
205
+
206
+ def fig2img(fig):
207
+ """Convert a Matplotlib figure to a PIL Image and return it"""
208
+ import io
209
+ buf = io.BytesIO()
210
+ fig.savefig(buf)
211
+ buf.seek(0)
212
+ img = Image.open(buf)
213
+ return img
214
+
215
+ def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25):
216
+ inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
217
+
218
+ smoothgrad_gen = SmoothGradient(
219
+ model, num_samples=num_samples, stdev_spread=0.1,
220
+ magnitude=False, postprocess=clip_gradient)
221
+
222
+ if type(inputs) != torch.Tensor:
223
+ inputs = inputs['pixel_values']
224
+
225
+ smoothgrad_mask = smoothgrad_gen(inputs, prediction_class)
226
+ smoothgrad_mask = smoothgrad_mask[0].numpy()
227
+ smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0))
228
+
229
+ image = np.asarray(image)
230
+ # ori_image = ShowImage(image)
231
+ heat_map_image = ShowHeatMap(smoothgrad_mask)
232
+ masked_image = ShowMaskedImage(smoothgrad_mask, image)
233
+ return heat_map_image, masked_image
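A minimal sketch of driving generate_smoothgrad_mask outside the Streamlit app, mirroring how pages/2_SmoothGrad.py loads the ConvNeXt checkpoint; the input file cat.jpg is a placeholder:

from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from backend.smooth_grad import generate_smoothgrad_mask

# Same checkpoint the app uses for ConvNeXt (see pages/2_SmoothGrad.py).
checkpoint = 'facebook/convnext-tiny-224'
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint, crop_pct=1.0)
model = AutoModelForImageClassification.from_pretrained(checkpoint)
model.eval()

image = Image.open('cat.jpg').convert('RGB')  # placeholder image path

# Returns the SmoothGrad heat map and the heat map blended onto the resized input,
# both as 224x224x3 uint8 arrays in OpenCV BGR channel order.
heat_map, masked = generate_smoothgrad_mask(
    image, 'ConvNeXt', model=model, feature_extractor=feature_extractor, num_samples=25)
print(heat_map.shape, masked.shape)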
backend/utils.py CHANGED
@@ -1,3 +1,334 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fbde6798de2a984418ec73610da9c15ea8a5582ab366a8376ab1481d5cff4355
3
- size 9589
1
+ import streamlit as st
2
+ import pickle
3
+
4
+ import io
5
+ from typing import List, Optional
6
+
7
+ import markdown
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import pandas as pd
11
+ import plotly.graph_objects as go
12
+ import streamlit as st
13
+ from plotly import express as px
14
+ from plotly.subplots import make_subplots
15
+ from tqdm import trange
16
+
17
+ @st.cache(allow_output_mutation=True)
18
+ def load_dataset(data_index):
19
+ with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
20
+ dataset = pickle.load(file)
21
+ return dataset
22
+
23
+ def load_dataset_dict():
24
+ dataset_dict = {}
25
+ progress_empty = st.empty()
26
+ text_empty = st.empty()
27
+ text_empty.write("Loading datasets...")
28
+ progress_bar = progress_empty.progress(0.0)
29
+ for data_index in trange(5):
30
+ dataset_dict[data_index] = load_dataset(data_index)
31
+ progress_bar.progress((data_index+1)/5)
32
+ progress_empty.empty()
33
+ text_empty.empty()
34
+ return dataset_dict
35
+
36
+ def make_grid(cols=None,rows=None):
37
+ grid = [0]*rows
38
+ for i in range(rows):
39
+ with st.container():
40
+ grid[i] = st.columns(cols)
41
+ return grid
42
+
43
+
44
+ def use_container_width_percentage(percentage_width:int = 75):
45
+ max_width_str = f"max-width: {percentage_width}%;"
46
+ st.markdown(f"""
47
+ <style>
48
+ .reportview-container .main .block-container{{{max_width_str}}}
49
+ </style>
50
+ """,
51
+ unsafe_allow_html=True,
52
+ )
53
+
54
+ matplotlib.use("Agg")
55
+ COLOR = "#31333f"
56
+ BACKGROUND_COLOR = "#ffffff"
57
+
58
+
59
+ def grid_demo():
60
+ """Main function. Run this to run the app"""
61
+ st.sidebar.title("Layout and Style Experiments")
62
+ st.sidebar.header("Settings")
63
+ st.markdown(
64
+ """
65
+ # Layout and Style Experiments
66
+
67
+ The basic question is: Can we create a multi-column dashboard with plots, numbers and text using
68
+ the [CSS Grid](https://gridbyexample.com/examples)?
69
+
70
+ Can we do it with a nice api?
71
+ Can we have a dark theme?
72
+ """
73
+ )
74
+
75
+ select_block_container_style()
76
+ add_resources_section()
77
+
78
+ # My preliminary idea of an API for generating a grid
79
+ with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid:
80
+ grid.cell(
81
+ class_="a",
82
+ grid_column_start=2,
83
+ grid_column_end=3,
84
+ grid_row_start=1,
85
+ grid_row_end=2,
86
+ ).markdown("# This is A Markdown Cell")
87
+ grid.cell("b", 2, 3, 2, 3).text("The cell to the left is a dataframe")
88
+ grid.cell("c", 3, 4, 2, 3).plotly_chart(get_plotly_fig())
89
+ grid.cell("d", 1, 2, 1, 3).dataframe(get_dataframe())
90
+ grid.cell("e", 3, 4, 1, 2).markdown(
91
+ "Try changing the **block container style** in the sidebar!"
92
+ )
93
+ grid.cell("f", 1, 3, 3, 4).text(
94
+ "The cell to the right is a matplotlib svg image"
95
+ )
96
+ grid.cell("g", 3, 4, 3, 4).pyplot(get_matplotlib_plt())
97
+
98
+
99
+ def add_resources_section():
100
+ """Adds a resources section to the sidebar"""
101
+ st.sidebar.header("Add_resources_section")
102
+ st.sidebar.markdown(
103
+ """
104
+ - [gridbyexample.com](https://gridbyexample.com/examples/)
105
+ """
106
+ )
107
+
108
+
109
+ class Cell:
110
+ """A Cell can hold text, markdown, plots etc."""
111
+
112
+ def __init__(
113
+ self,
114
+ class_: str = None,
115
+ grid_column_start: Optional[int] = None,
116
+ grid_column_end: Optional[int] = None,
117
+ grid_row_start: Optional[int] = None,
118
+ grid_row_end: Optional[int] = None,
119
+ ):
120
+ self.class_ = class_
121
+ self.grid_column_start = grid_column_start
122
+ self.grid_column_end = grid_column_end
123
+ self.grid_row_start = grid_row_start
124
+ self.grid_row_end = grid_row_end
125
+ self.inner_html = ""
126
+
127
+ def _to_style(self) -> str:
128
+ return f"""
129
+ .{self.class_} {{
130
+ grid-column-start: {self.grid_column_start};
131
+ grid-column-end: {self.grid_column_end};
132
+ grid-row-start: {self.grid_row_start};
133
+ grid-row-end: {self.grid_row_end};
134
+ }}
135
+ """
136
+
137
+ def text(self, text: str = ""):
138
+ self.inner_html = text
139
+
140
+ def markdown(self, text):
141
+ self.inner_html = markdown.markdown(text)
142
+
143
+ def dataframe(self, dataframe: pd.DataFrame):
144
+ self.inner_html = dataframe.to_html()
145
+
146
+ def plotly_chart(self, fig):
147
+ self.inner_html = f"""
148
+ <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
149
+ <body>
150
+ <p>This should have been a plotly plot.
151
+ But since *script* tags are removed when inserting Markdown/HTML, I cannot get it to work.
152
+ But I could potentially save to svg and insert that.</p>
153
+ <div id='divPlotly'></div>
154
+ <script>
155
+ var plotly_data = {fig.to_json()}
156
+ Plotly.react('divPlotly', plotly_data.data, plotly_data.layout);
157
+ </script>
158
+ </body>
159
+ """
160
+
161
+ def pyplot(self, fig=None, **kwargs):
162
+ string_io = io.StringIO()
163
+ plt.savefig(string_io, format="svg", fig=(2, 2))
164
+ svg = string_io.getvalue()[215:]
165
+ plt.close(fig)
166
+ self.inner_html = '<div height="200px">' + svg + "</div>"
167
+
168
+ def _to_html(self):
169
+ return f"""<div class="box {self.class_}">{self.inner_html}</div>"""
170
+
171
+
172
+ class Grid:
173
+ """A (CSS) Grid"""
174
+
175
+ def __init__(
176
+ self,
177
+ template_columns="1 1 1",
178
+ gap="10px",
179
+ background_color=COLOR,
180
+ color=BACKGROUND_COLOR,
181
+ ):
182
+ self.template_columns = template_columns
183
+ self.gap = gap
184
+ self.background_color = background_color
185
+ self.color = color
186
+ self.cells: List[Cell] = []
187
+
188
+ def __enter__(self):
189
+ return self
190
+
191
+ def __exit__(self, type, value, traceback):
192
+ st.markdown(self._get_grid_style(), unsafe_allow_html=True)
193
+ st.markdown(self._get_cells_style(), unsafe_allow_html=True)
194
+ st.markdown(self._get_cells_html(), unsafe_allow_html=True)
195
+
196
+ def _get_grid_style(self):
197
+ return f"""
198
+ <style>
199
+ .wrapper {{
200
+ display: grid;
201
+ grid-template-columns: {self.template_columns};
202
+ grid-gap: {self.gap};
203
+ background-color: {self.color};
204
+ color: {self.background_color};
205
+ }}
206
+ .box {{
207
+ background-color: {self.color};
208
+ color: {self.background_color};
209
+ border-radius: 0px;
210
+ padding: 0px;
211
+ font-size: 100%;
212
+ text-align: center;
213
+ }}
214
+ table {{
215
+ color: {self.color}
216
+ }}
217
+ </style>
218
+ """
219
+
220
+ def _get_cells_style(self):
221
+ return (
222
+ "<style>"
223
+ + "\n".join([cell._to_style() for cell in self.cells])
224
+ + "</style>"
225
+ )
226
+
227
+ def _get_cells_html(self):
228
+ return (
229
+ '<div class="wrapper">'
230
+ + "\n".join([cell._to_html() for cell in self.cells])
231
+ + "</div>"
232
+ )
233
+
234
+ def cell(
235
+ self,
236
+ class_: str = None,
237
+ grid_column_start: Optional[int] = None,
238
+ grid_column_end: Optional[int] = None,
239
+ grid_row_start: Optional[int] = None,
240
+ grid_row_end: Optional[int] = None,
241
+ ):
242
+ cell = Cell(
243
+ class_=class_,
244
+ grid_column_start=grid_column_start,
245
+ grid_column_end=grid_column_end,
246
+ grid_row_start=grid_row_start,
247
+ grid_row_end=grid_row_end,
248
+ )
249
+ self.cells.append(cell)
250
+ return cell
251
+
252
+
253
+ def select_block_container_style():
254
+ """Add a selection section for setting the max-width and padding
255
+ of the main block container"""
256
+ st.sidebar.header("Block Container Style")
257
+ max_width_100_percent = st.sidebar.checkbox("Max-width: 100%?", False)
258
+ if not max_width_100_percent:
259
+ max_width = st.sidebar.slider("Select max-width in px", 100, 2000, 1200, 100)
260
+ else:
261
+ max_width = 1200
262
+ dark_theme = st.sidebar.checkbox("Dark Theme?", False)
263
+ padding_top = st.sidebar.number_input("Select padding top in rem", 0, 200, 5, 1)
264
+ padding_right = st.sidebar.number_input("Select padding right in rem", 0, 200, 1, 1)
265
+ padding_left = st.sidebar.number_input("Select padding left in rem", 0, 200, 1, 1)
266
+ padding_bottom = st.sidebar.number_input(
267
+ "Select padding bottom in rem", 0, 200, 10, 1
268
+ )
269
+ if dark_theme:
270
+ global COLOR
271
+ global BACKGROUND_COLOR
272
+ BACKGROUND_COLOR = "rgb(17,17,17)"
273
+ COLOR = "#fff"
274
+
275
+ _set_block_container_style(
276
+ max_width,
277
+ max_width_100_percent,
278
+ padding_top,
279
+ padding_right,
280
+ padding_left,
281
+ padding_bottom,
282
+ )
283
+
284
+
285
+ def _set_block_container_style(
286
+ max_width: int = 1200,
287
+ max_width_100_percent: bool = False,
288
+ padding_top: int = 5,
289
+ padding_right: int = 1,
290
+ padding_left: int = 1,
291
+ padding_bottom: int = 10,
292
+ ):
293
+ if max_width_100_percent:
294
+ max_width_str = f"max-width: 100%;"
295
+ else:
296
+ max_width_str = f"max-width: {max_width}px;"
297
+ st.markdown(
298
+ f"""
299
+ <style>
300
+ .reportview-container .main .block-container{{
301
+ {max_width_str}
302
+ padding-top: {padding_top}rem;
303
+ padding-right: {padding_right}rem;
304
+ padding-left: {padding_left}rem;
305
+ padding-bottom: {padding_bottom}rem;
306
+ }}
307
+ .reportview-container .main {{
308
+ color: {COLOR};
309
+ background-color: {BACKGROUND_COLOR};
310
+ }}
311
+ </style>
312
+ """,
313
+ unsafe_allow_html=True,
314
+ )
315
+
316
+
317
+ @st.cache
318
+ def get_dataframe() -> pd.DataFrame():
319
+ """Dummy DataFrame"""
320
+ data = [
321
+ {"quantity": 1, "price": 2},
322
+ {"quantity": 3, "price": 5},
323
+ {"quantity": 4, "price": 8},
324
+ ]
325
+ return pd.DataFrame(data)
326
+
327
+
328
+ def get_plotly_fig():
329
+ """Dummy Plotly Plot"""
330
+ return px.line(data_frame=get_dataframe(), x="quantity", y="price")
331
+
332
+
333
+ def get_matplotlib_plt():
334
+ get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
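make_grid above returns a rows x cols matrix of st.columns cells; a short usage sketch (the captions are illustrative):

import streamlit as st
from backend.utils import make_grid

# A 2x3 grid of Streamlit column containers.
grid = make_grid(cols=3, rows=2)
for i in range(2):
    for j in range(3):
        # Each cell behaves like a normal Streamlit container.
        grid[i][j].write(f'cell ({i}, {j})')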
data/ImageNet_metadata.csv CHANGED
The diff for this file is too large to render. See raw diff
 
data/dot_architectures/convnext_architecture.dot CHANGED
@@ -1,3 +1,194 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:41a258a40a93615638ae504770c14e44836c934badbe48f18148f5a750514ac9
3
- size 9108
1
+ digraph convnext {
2
+ graph [labeljust=left ranksep=0.2]
3
+ node [fixedsize=true shape=box width=3]
4
+ edge [arrowhead=vee]
5
+ image [label=image shape=plaintext]
6
+ embeddings_conv [label="4x4 conv, 96, stride 4"]
7
+ image -> embeddings_conv
8
+ subgraph "cluster stage 1" {
9
+ graph [color="#e9f2f7" label="stage 1\l" style=filled]
10
+ node [color="#bcd9e7" style=filled]
11
+ subgraph "cluster stage 1 block 1" {
12
+ graph [color="#0f4158" label="block 1\l" style=dashed]
13
+ "stage 1 block 1 dwconv" [label="d7x7 conv, 96, stride 1"]
14
+ "stage 1 block 1 pwconv1" [label="1x1 conv, 384"]
15
+ "stage 1 block 1 pwconv2" [label="1x1 conv, 96"]
16
+ "stage 1 block 1 dwconv" -> "stage 1 block 1 pwconv1"
17
+ "stage 1 block 1 pwconv1" -> "stage 1 block 1 pwconv2"
18
+ }
19
+ subgraph "cluster stage 1 block 2" {
20
+ graph [color="#0f4158" label="block 2\l" style=dashed]
21
+ "stage 1 block 2 dwconv" [label="d7x7 conv, 96, stride 1"]
22
+ "stage 1 block 2 pwconv1" [label="1x1 conv, 384"]
23
+ "stage 1 block 2 pwconv2" [label="1x1 conv, 96"]
24
+ "stage 1 block 2 dwconv" -> "stage 1 block 2 pwconv1"
25
+ "stage 1 block 2 pwconv1" -> "stage 1 block 2 pwconv2"
26
+ }
27
+ subgraph "cluster stage 1 block 3" {
28
+ graph [color="#0f4158" label="block 3\l" style=dashed]
29
+ "stage 1 block 3 dwconv" [label="d7x7 conv, 96, stride 1"]
30
+ "stage 1 block 3 pwconv1" [label="1x1 conv, 384"]
31
+ "stage 1 block 3 pwconv2" [label="1x1 conv, 96"]
32
+ "stage 1 block 3 dwconv" -> "stage 1 block 3 pwconv1"
33
+ "stage 1 block 3 pwconv1" -> "stage 1 block 3 pwconv2"
34
+ }
35
+ "stage 1 block 1 pwconv2" -> "stage 1 block 2 dwconv"
36
+ "stage 1 block 2 pwconv2" -> "stage 1 block 3 dwconv"
37
+ }
38
+ subgraph "cluster stage 2" {
39
+ graph [color="#e9f7f3" label="stage 2\l" style=filled]
40
+ node [color="#bce7db" style=filled]
41
+ "stage 2 downsampling" [label="2x2 conv, 192, stride 2"]
42
+ subgraph "cluster stage 2 block 1" {
43
+ graph [color="#0f5844" label="block 1\l" style=dashed]
44
+ "stage 2 block 1 dwconv" [label="d7x7 conv, 192, stride 1"]
45
+ "stage 2 block 1 pwconv1" [label="1x1 conv, 768"]
46
+ "stage 2 block 1 pwconv2" [label="1x1 conv, 192"]
47
+ "stage 2 block 1 dwconv" -> "stage 2 block 1 pwconv1"
48
+ "stage 2 block 1 pwconv1" -> "stage 2 block 1 pwconv2"
49
+ }
50
+ subgraph "cluster stage 2 block 2" {
51
+ graph [color="#0f5844" label="block 2\l" style=dashed]
52
+ "stage 2 block 2 dwconv" [label="d7x7 conv, 192, stride 1"]
53
+ "stage 2 block 2 pwconv1" [label="1x1 conv, 768"]
54
+ "stage 2 block 2 pwconv2" [label="1x1 conv, 192"]
55
+ "stage 2 block 2 dwconv" -> "stage 2 block 2 pwconv1"
56
+ "stage 2 block 2 pwconv1" -> "stage 2 block 2 pwconv2"
57
+ }
58
+ subgraph "cluster stage 2 block 3" {
59
+ graph [color="#0f5844" label="block 3\l" style=dashed]
60
+ "stage 2 block 3 dwconv" [label="d7x7 conv, 192, stride 1"]
61
+ "stage 2 block 3 pwconv1" [label="1x1 conv, 768"]
62
+ "stage 2 block 3 pwconv2" [label="1x1 conv, 192"]
63
+ "stage 2 block 3 dwconv" -> "stage 2 block 3 pwconv1"
64
+ "stage 2 block 3 pwconv1" -> "stage 2 block 3 pwconv2"
65
+ }
66
+ "stage 2 downsampling" -> "stage 2 block 1 dwconv"
67
+ "stage 2 block 1 pwconv2" -> "stage 2 block 2 dwconv"
68
+ "stage 2 block 2 pwconv2" -> "stage 2 block 3 dwconv"
69
+ }
70
+ subgraph "cluster stage 3" {
71
+ graph [color="#e9f2f7" label="stage 3\l" style=filled]
72
+ node [color="#bcd9e7" style=filled]
73
+ "stage 3 downsampling" [label="2x2 conv, 384, stride 2"]
74
+ subgraph "cluster stage 3 block 1" {
75
+ graph [color="#0f4158" label="block 1\l" style=dashed]
76
+ "stage 3 block 1 dwconv" [label="d7x7 conv, 384, stride 1"]
77
+ "stage 3 block 1 pwconv1" [label="1x1 conv, 1536"]
78
+ "stage 3 block 1 pwconv2" [label="1x1 conv, 384"]
79
+ "stage 3 block 1 dwconv" -> "stage 3 block 1 pwconv1"
80
+ "stage 3 block 1 pwconv1" -> "stage 3 block 1 pwconv2"
81
+ }
82
+ subgraph "cluster stage 3 block 2" {
83
+ graph [color="#0f4158" label="block 2\l" style=dashed]
84
+ "stage 3 block 2 dwconv" [label="d7x7 conv, 384, stride 1"]
85
+ "stage 3 block 2 pwconv1" [label="1x1 conv, 1536"]
86
+ "stage 3 block 2 pwconv2" [label="1x1 conv, 384"]
87
+ "stage 3 block 2 dwconv" -> "stage 3 block 2 pwconv1"
88
+ "stage 3 block 2 pwconv1" -> "stage 3 block 2 pwconv2"
89
+ }
90
+ subgraph "cluster stage 3 block 3" {
91
+ graph [color="#0f4158" label="block 3\l" style=dashed]
92
+ "stage 3 block 3 dwconv" [label="d7x7 conv, 384, stride 1"]
93
+ "stage 3 block 3 pwconv1" [label="1x1 conv, 1536"]
94
+ "stage 3 block 3 pwconv2" [label="1x1 conv, 384"]
95
+ "stage 3 block 3 dwconv" -> "stage 3 block 3 pwconv1"
96
+ "stage 3 block 3 pwconv1" -> "stage 3 block 3 pwconv2"
97
+ }
98
+ subgraph "cluster stage 3 block 4" {
99
+ graph [color="#0f4158" label="block 4\l" style=dashed]
100
+ "stage 3 block 4 dwconv" [label="d7x7 conv, 384, stride 1"]
101
+ "stage 3 block 4 pwconv1" [label="1x1 conv, 1536"]
102
+ "stage 3 block 4 pwconv2" [label="1x1 conv, 384"]
103
+ "stage 3 block 4 dwconv" -> "stage 3 block 4 pwconv1"
104
+ "stage 3 block 4 pwconv1" -> "stage 3 block 4 pwconv2"
105
+ }
106
+ subgraph "cluster stage 3 block 5" {
107
+ graph [color="#0f4158" label="block 5\l" style=dashed]
108
+ "stage 3 block 5 dwconv" [label="d7x7 conv, 384, stride 1"]
109
+ "stage 3 block 5 pwconv1" [label="1x1 conv, 1536"]
110
+ "stage 3 block 5 pwconv2" [label="1x1 conv, 384"]
111
+ "stage 3 block 5 dwconv" -> "stage 3 block 5 pwconv1"
112
+ "stage 3 block 5 pwconv1" -> "stage 3 block 5 pwconv2"
113
+ }
114
+ subgraph "cluster stage 3 block 6" {
115
+ graph [color="#0f4158" label="block 6\l" style=dashed]
116
+ "stage 3 block 6 dwconv" [label="d7x7 conv, 384, stride 1"]
117
+ "stage 3 block 6 pwconv1" [label="1x1 conv, 1536"]
118
+ "stage 3 block 6 pwconv2" [label="1x1 conv, 384"]
119
+ "stage 3 block 6 dwconv" -> "stage 3 block 6 pwconv1"
120
+ "stage 3 block 6 pwconv1" -> "stage 3 block 6 pwconv2"
121
+ }
122
+ subgraph "cluster stage 3 block 7" {
123
+ graph [color="#0f4158" label="block 7\l" style=dashed]
124
+ "stage 3 block 7 dwconv" [label="d7x7 conv, 384, stride 1"]
125
+ "stage 3 block 7 pwconv1" [label="1x1 conv, 1536"]
126
+ "stage 3 block 7 pwconv2" [label="1x1 conv, 384"]
127
+ "stage 3 block 7 dwconv" -> "stage 3 block 7 pwconv1"
128
+ "stage 3 block 7 pwconv1" -> "stage 3 block 7 pwconv2"
129
+ }
130
+ subgraph "cluster stage 3 block 8" {
131
+ graph [color="#0f4158" label="block 8\l" style=dashed]
132
+ "stage 3 block 8 dwconv" [label="d7x7 conv, 384, stride 1"]
133
+ "stage 3 block 8 pwconv1" [label="1x1 conv, 1536"]
134
+ "stage 3 block 8 pwconv2" [label="1x1 conv, 384"]
135
+ "stage 3 block 8 dwconv" -> "stage 3 block 8 pwconv1"
136
+ "stage 3 block 8 pwconv1" -> "stage 3 block 8 pwconv2"
137
+ }
138
+ subgraph "cluster stage 3 block 9" {
139
+ graph [color="#0f4158" label="block 9\l" style=dashed]
140
+ "stage 3 block 9 dwconv" [label="d7x7 conv, 384, stride 1"]
141
+ "stage 3 block 9 pwconv1" [label="1x1 conv, 1536"]
142
+ "stage 3 block 9 pwconv2" [label="1x1 conv, 384"]
143
+ "stage 3 block 9 dwconv" -> "stage 3 block 9 pwconv1"
144
+ "stage 3 block 9 pwconv1" -> "stage 3 block 9 pwconv2"
145
+ }
146
+ "stage 3 downsampling" -> "stage 3 block 1 dwconv"
147
+ "stage 3 block 1 pwconv2" -> "stage 3 block 2 dwconv"
148
+ "stage 3 block 2 pwconv2" -> "stage 3 block 3 dwconv"
149
+ "stage 3 block 3 pwconv2" -> "stage 3 block 4 dwconv"
150
+ "stage 3 block 4 pwconv2" -> "stage 3 block 5 dwconv"
151
+ "stage 3 block 5 pwconv2" -> "stage 3 block 6 dwconv"
152
+ "stage 3 block 6 pwconv2" -> "stage 3 block 7 dwconv"
153
+ "stage 3 block 7 pwconv2" -> "stage 3 block 8 dwconv"
154
+ "stage 3 block 8 pwconv2" -> "stage 3 block 9 dwconv"
155
+ }
156
+ subgraph "cluster stage 4" {
157
+ graph [color="#e9f7f3" label="stage 4\l" style=filled]
158
+ node [color="#bce7db" style=filled]
159
+ "stage 4 downsampling" [label="2x2 conv, 768, stride 2"]
160
+ subgraph "cluster stage 4 block 1" {
161
+ graph [color="#0f5844" label="block 1\l" style=dashed]
162
+ "stage 4 block 1 dwconv" [label="d7x7 conv, 768, stride 1"]
163
+ "stage 4 block 1 pwconv1" [label="1x1 conv, 3072"]
164
+ "stage 4 block 1 pwconv2" [label="1x1 conv, 768"]
165
+ "stage 4 block 1 dwconv" -> "stage 4 block 1 pwconv1"
166
+ "stage 4 block 1 pwconv1" -> "stage 4 block 1 pwconv2"
167
+ }
168
+ subgraph "cluster stage 4 block 2" {
169
+ graph [color="#0f5844" label="block 2\l" style=dashed]
170
+ "stage 4 block 2 dwconv" [label="d7x7 conv, 768, stride 1"]
171
+ "stage 4 block 2 pwconv1" [label="1x1 conv, 3072"]
172
+ "stage 4 block 2 pwconv2" [label="1x1 conv, 768"]
173
+ "stage 4 block 2 dwconv" -> "stage 4 block 2 pwconv1"
174
+ "stage 4 block 2 pwconv1" -> "stage 4 block 2 pwconv2"
175
+ }
176
+ subgraph "cluster stage 4 block 3" {
177
+ graph [color="#0f5844" label="block 3\l" style=dashed]
178
+ "stage 4 block 3 dwconv" [label="d7x7 conv, 768, stride 1"]
179
+ "stage 4 block 3 pwconv1" [label="1x1 conv, 3072"]
180
+ "stage 4 block 3 pwconv2" [label="1x1 conv, 768"]
181
+ "stage 4 block 3 dwconv" -> "stage 4 block 3 pwconv1"
182
+ "stage 4 block 3 pwconv1" -> "stage 4 block 3 pwconv2"
183
+ }
184
+ "stage 4 downsampling" -> "stage 4 block 1 dwconv"
185
+ "stage 4 block 1 pwconv2" -> "stage 4 block 2 dwconv"
186
+ "stage 4 block 2 pwconv2" -> "stage 4 block 3 dwconv"
187
+ }
188
+ "stage 1 block 3 pwconv2" -> "stage 2 downsampling"
189
+ "stage 2 block 3 pwconv2" -> "stage 3 downsampling"
190
+ "stage 3 block 9 pwconv2" -> "stage 4 downsampling"
191
+ embeddings_conv -> "stage 1 block 1 dwconv"
192
+ "stage 4 block 3 pwconv2" -> "output vector"
193
+ "output vector" [label="output vector" shape=plaintext]
194
+ }
data/preprocessed_image_net/val_data_0.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2698bdc240555e2a46a40936df87275bc5852142d30e921ae0dad9289b0f576f
3
+ size 906108480
data/preprocessed_image_net/val_data_1.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21780d77e212695dbee84d6d2ad17a5a520bc1634f68e1c8fd120f069ad76da1
3
+ size 907109023
data/preprocessed_image_net/val_data_2.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cfc83b78420baa1b2c3a8da92e7fba1f33443d506f483ecff13cdba2035ab3c
3
+ size 907435149
data/preprocessed_image_net/val_data_3.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f5e2c7cb4d6bae17fbd062a0b46f2cee457ad466b725f7bdf0f8426069cafee
3
+ size 906089333
data/preprocessed_image_net/val_data_4.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed53c87ec8b9945db31f910eb44b7e3092324643de25ea53a99fc29137df854
3
+ size 905439763
frontend/__init__.py CHANGED
@@ -1,3 +1,6 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:34811d4931ed1903c27bc1e6b639b0e9cf5f4d0aaff202daa0f853660097739d
3
- size 138
1
+ import streamlit.components.v1 as components
2
+
3
+ on_click_graph = components.declare_component(
4
+ "on_click_graph",
5
+ path="./frontend"
6
+ )
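The on_click_graph component declared above is consumed from the pages; a minimal sketch mirroring the call in pages/1_Maximally_activating_patches.py (the key and color here are illustrative):

import streamlit as st
from frontend import on_click_graph

# hightlight_color is read by initializeProps_Handler in frontend/index.html.
props = {'hightlight_color': '#e7bcc5'}
nodes = on_click_graph(key='architecture_graph', **props)

if nodes is not None:
    # The component reports the clicked node's title, id, and toggle state.
    st.write('Chosen node:', nodes['choice']['node_title'])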
frontend/index.html CHANGED
@@ -1,3 +1,204 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0bbc20b497b89adc1a1c2882055de0c04401843d59f0b1231d11281b1abe197b
3
- size 6651
1
+ <html>
2
+
3
+ <head>
4
+ <style type="text/css">
5
+ </style>
6
+ </head>
7
+
8
+ <!--
9
+ ----------------------------------------------------
10
+ Your custom static HTML goes in the body:
11
+ -->
12
+
13
+ <body>
14
+ </body>
15
+
16
+ <script type="text/javascript">
17
+ // Helper function to send type and data messages to Streamlit client
18
+
19
+ const SET_COMPONENT_VALUE = "streamlit:setComponentValue"
20
+ const RENDER = "streamlit:render"
21
+ const COMPONENT_READY = "streamlit:componentReady"
22
+ const SET_FRAME_HEIGHT = "streamlit:setFrameHeight"
23
+ var HIGHTLIGHT_COLOR;
24
+ var original_colors;
25
+
26
+ function _sendMessage(type, data) {
27
+ // copy data into object
28
+ var outboundData = Object.assign({
29
+ isStreamlitMessage: true,
30
+ type: type,
31
+ }, data)
32
+
33
+ if (type == SET_COMPONENT_VALUE) {
34
+ console.log("_sendMessage data: ", SET_COMPONENT_VALUE)
35
+ // console.log("_sendMessage data: " + JSON.stringify(data))
36
+ // console.log("_sendMessage outboundData: " + JSON.stringify(outboundData))
37
+ }
38
+
39
+ window.parent.postMessage(outboundData, "*")
40
+ }
41
+
42
+ function initialize(pipeline) {
43
+
44
+ // Hook Streamlit's message events into a simple dispatcher of pipeline handlers
45
+ window.addEventListener("message", (event) => {
46
+ if (event.data.type == RENDER) {
47
+ // The event.data.args dict holds any JSON-serializable value
48
+ // sent from the Streamlit client. It is already deserialized.
49
+ pipeline.forEach(handler => {
50
+ handler(event.data.args)
51
+ })
52
+ }
53
+ })
54
+
55
+ _sendMessage(COMPONENT_READY, { apiVersion: 1 });
56
+
57
+ // Component should be mounted by Streamlit in an iframe, so try to autoset the iframe height.
58
+ window.addEventListener("load", () => {
59
+ window.setTimeout(function () {
60
+ setFrameHeight(document.documentElement.clientHeight)
61
+ }, 0)
62
+ })
63
+
64
+ // Optionally, if auto-height computation fails, you can manually set it
65
+ // (uncomment below)
66
+ setFrameHeight(0)
67
+ }
68
+
69
+ function setFrameHeight(height) {
70
+ _sendMessage(SET_FRAME_HEIGHT, { height: height })
71
+ }
72
+
73
+ // The `data` argument can be any JSON-serializable value.
74
+ function notifyHost(data) {
75
+ _sendMessage(SET_COMPONENT_VALUE, data)
76
+ }
77
+
78
+ function changeButtonColor(button, color) {
79
+ pol = button.querySelectorAll('polygon')[0]
80
+ pol.setAttribute('fill', color)
81
+ pol.setAttribute('stroke', color)
82
+ }
83
+
84
+ function getButtonColor(button) {
85
+ pol = button.querySelectorAll('polygon')[0]
86
+ return pol.getAttribute('fill')
87
+ }
88
+ // ----------------------------------------------------
89
+ // Your custom functionality for the component goes here:
90
+
91
+ function toggle(button) {
92
+ group = 'node'
93
+ let button_color;
94
+ nodes = window.parent.document.getElementsByClassName('node')
95
+ console.log("nodes.length = ", nodes.length)
96
+ // for (let i = 0; i < nodes.length; i++) {
97
+ // console.log(nodes.item(i))
98
+ // }
99
+ console.log("selected button ", button, button.getAttribute('class'), button.id)
100
+
101
+ for (let i = 0; i < nodes.length; i++) {
102
+ polygons = nodes.item(i).querySelectorAll('polygon')
103
+ if (polygons.length == 0) {
104
+ continue
105
+ }
106
+ if (button.id == nodes.item(i).id & button.getAttribute('class').includes("off")) {
107
+ button.setAttribute('class', group + " on")
108
+ button_color = original_colors[i]
109
+
110
+ } else if (button.id == nodes.item(i).id & button.getAttribute('class').includes("on")) {
111
+ button.setAttribute('class', group + " off")
112
+ button_color = original_colors[i]
113
+ } else if (button.id == nodes.item(i).id) {
114
+ button.setAttribute('class', group + " on")
115
+ button_color = original_colors[i]
116
+
117
+ } else if (button.id != nodes.item(i).id & nodes.item(i).getAttribute('class').includes("on")) {
118
+ nodes.item(i).className = group + " off"
119
+ } else {
120
+ nodes.item(i).className = group + " off"
121
+ }
122
+ }
123
+
124
+ nodes = window.parent.document.getElementsByClassName('node')
125
+ actions = []
126
+ for (let i = 0; i < nodes.length; i++) {
127
+ polygons = nodes.item(i).querySelectorAll('polygon')
128
+ if (polygons.length == 0) {
129
+ continue
130
+ }
131
+ btn = nodes.item(i)
132
+ ori_color = original_colors[i]
133
+ color = btn.querySelectorAll('polygon')[0].getAttribute('fill')
134
+ actions.push({ "action": btn.getAttribute("class").includes("on"), "original_color": ori_color, "color": color})
135
+ }
136
+
137
+ states = {}
138
+ states['choice'] = {
139
+ "node_title": button.querySelectorAll("title")[0].innerHTML,
140
+ "node_id": button.id,
141
+ "state": {
142
+ "action": button.getAttribute("class").includes("on"),
143
+ "original_color": button_color,
144
+ "color": button.querySelectorAll('polygon')[0].getAttribute('fill')
145
+ }
146
+ }
147
+ states["options"] = {"states": actions }
148
+
149
+ notifyHost({
150
+ value: states,
151
+ dataType: "json",
152
+ })
153
+ }
154
+
155
+ // ----------------------------------------------------
156
+ // Here you can customize a pipeline of handlers for
157
+ // inbound properties from the Streamlit client app
158
+
159
+ // Set initial value sent from Streamlit!
160
+ function initializeProps_Handler(props) {
161
+ HIGHTLIGHT_COLOR = props['hightlight_color']
162
+ original_colors = []
163
+ // nodes = document.getElementsByClassName('node')
164
+ nodes = window.parent.document.getElementsByClassName('node')
165
+ console.log(nodes)
166
+ for (let i = 0; i < nodes.length; i++) {
167
+ // color = nodes.item(i).getElementsByTagName('POLYGON')[0].getAttribute("fill")
168
+ // nodes.item(i).addEventListener("click", toggle)
169
+ polygons = nodes.item(i).querySelectorAll('polygon')
170
+ if (polygons.length == 0) {
171
+ original_colors.push('none')
172
+ continue
173
+ }
174
+
175
+ color = polygons[0].getAttribute("fill")
176
+ if (!nodes.item(i).hasAttribute('color')) {
177
+ nodes.item(i).setAttribute("color", color)
178
+ original_colors.push(color)
179
+ } else {
180
+ original_colors.push(nodes.item(i).getAttribute("color"))
181
+ }
182
+ nodes.item(i).addEventListener("click", function (event) {toggle(this)})
183
+ }
184
+ // console.log("original colors:", original_colors)
185
+ }
186
+ // Access values sent from Streamlit!
187
+ function dataUpdate_Handler(props) {
188
+ console.log('dataUpdate_Handler...........')
189
+ let msgLabel = document.getElementById("message_label")
190
+ }
191
+ // Simply log received data dictionary
192
+ function log_Handler(props) {
193
+ console.log("Received from Streamlit: " + JSON.stringify(props))
194
+ }
195
+
196
+ let pipeline = [initializeProps_Handler, dataUpdate_Handler, log_Handler]
197
+
198
+ // ----------------------------------------------------
199
+ // Finally, initialize component passing in pipeline
200
+ initialize(pipeline)
201
+
202
+ </script>
203
+
204
+ </html>
load_file.py CHANGED
@@ -1,3 +1,37 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b31a832332b15b85d097b62cdcf616e4f8c242886908c1386dce7c73ff5b1e4a
3
- size 1202
1
+ import json
2
+ import pickle
3
+ import numpy as np
4
+ from collections import OrderedDict
5
+
6
+ def load_pickle(filename):
7
+ with open(filename, 'rb') as file:
8
+ data = pickle.load(file)
9
+ return data
10
+
11
+ def save_pickle_to_json(filename):
12
+ ordered_dict = load_pickle(filename)
13
+ json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
14
+ with open(filename.replace('.pkl', '.json'), 'w') as f:
15
+ f.write(json_obj)
16
+
17
+ def load_json(filename):
18
+ with open(filename, 'r') as read_file:
19
+ loaded_dict = json.loads(read_file.read())
20
+ loaded_dict = OrderedDict(loaded_dict)
21
+ for k, v in loaded_dict.items():
22
+ loaded_dict[k] = np.asarray(v)
23
+ return loaded_dict
24
+
25
+ class NumpyEncoder(json.JSONEncoder):
26
+ def default(self, obj):
27
+ if isinstance(obj, np.ndarray):
28
+ return obj.tolist()
29
+ return json.JSONEncoder.default(self, obj)
30
+
31
+ # save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
32
+ # save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
33
+ # save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
34
+
35
+ file = load_json('data/layer_infos/convnext_layer_infos.json')
36
+ print(type(file))
37
+ print(type(file['embeddings.patch_embeddings']))
pages/1_Maximally_activating_patches.py CHANGED
@@ -1,3 +1,161 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b726f208810db8cd10fddbbd73a60f6676a4bf415225565d5bcc4049b19a310e
3
- size 7434
1
+ import streamlit as st
2
+ import numpy as np
3
+
4
+ from plotly.subplots import make_subplots
5
+ import plotly.graph_objects as go
6
+
7
+ import graphviz
8
+
9
+ from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
10
+ from frontend import on_click_graph
11
+ from backend.utils import load_dataset_dict
12
+
13
+ HIGHTLIGHT_COLOR = '#e7bcc5'
14
+ st.set_page_config(layout='wide')
15
+
16
+ # -------------------------- LOAD DATASET ---------------------------------
17
+ dataset_dict = load_dataset_dict()
18
+
19
+ # -------------------------- LOAD GRAPH -----------------------------------
20
+
21
+ def load_dot_to_graph(filename):
22
+ dot = graphviz.Source.from_file(filename)
23
+ source_lines = str(dot).splitlines()
24
+ source_lines.pop(0)
25
+ source_lines.pop(-1)
26
+ graph = graphviz.Digraph()
27
+ graph.body += source_lines
28
+ return graph, dot
29
+
30
+ st.title('Maximally activating image patches')
31
+ st.write('Visualize image patches that maximize the activation of layers in three models: ConvNeXt, ResNet, MobileNet')
32
+
33
+ # st.header('ConvNeXt')
34
+ convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
35
+ convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
36
+
37
+ convnext_graph.graph_attr['size'] = '4,40'
38
+
39
+ # -------------------------- DISPLAY GRAPH -----------------------------------
40
+
41
+ def chosen_node_text(clicked_node_title):
42
+ clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
43
+ stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None
44
+ block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None
45
+ layer_id = clicked_node_title.split()[-1]
46
+
47
+ if 'embeddings' in layer_id:
48
+ display_text = 'Patchify layer'
49
+ activation_key = 'embeddings.patch_embeddings'
50
+ elif 'downsampling' in layer_id:
51
+ display_text = f'Stage {stage_id} > Downsampling layer'
52
+ activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]'
53
+ else:
54
+ display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
55
+ activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}'
56
+ return display_text, activation_key
57
+
58
+
59
+ props = {
60
+ 'hightlight_color': HIGHTLIGHT_COLOR,
61
+ 'initial_state': {
62
+ 'group_1_header': 'Choose an option from group 1',
63
+ 'group_2_header': 'Choose an option from group 2'
64
+ }
65
+ }
66
+
67
+ convnext_tab, resnet_tab, mobilenet_tab = st.tabs(['ConvNeXt', 'ResNet', 'MobileNet'])
68
+
69
+ with convnext_tab:
70
+ col1, col2 = st.columns((2,5))
71
+ col1.markdown("#### Architecture")
72
+ col1.write('')
73
+ col1.write('Click on a layer below to generate top-k maximally activating image patches')
74
+ col1.graphviz_chart(convnext_graph)
75
+
76
+ with col2:
77
+ st.markdown("#### Output")
78
+ nodes = on_click_graph(key='toggle_buttons', **props)
79
+
80
+ # -------------------------- DISPLAY OUTPUT -----------------------------------
81
+
82
+ if nodes != None:
83
+ clicked_node_title = nodes["choice"]["node_title"]
84
+ clicked_node_id = nodes["choice"]["node_id"]
85
+ display_text, activation_key = chosen_node_text(clicked_node_title)
86
+ col2.write(f'**Chosen layer:** {display_text}')
87
+ # col2.write(f'**Activation key:** {activation_key}')
88
+
89
+ hightlight_syle = f'''
90
+ <style>
91
+ div[data-stale]:has(iframe) {{
92
+ height: 0;
93
+ }}
94
+ #{clicked_node_id}>polygon {{
95
+ fill: {HIGHTLIGHT_COLOR};
96
+ stroke: {HIGHTLIGHT_COLOR};
97
+ }}
98
+ </style>
99
+ '''
100
+ col2.markdown(hightlight_syle, unsafe_allow_html=True)
101
+
102
+ with col2:
103
+ layer_infos = None
104
+ with st.form('top_k_form'):
105
+ activation_path = './data/activation/convnext_activation.json'
106
+ activation = load_activation(activation_path)
107
+ num_channels = activation[activation_key].shape[1]
108
+
109
+ top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10)
110
+ channel_start, channel_end = st.slider(
111
+ 'Choose channel range of this layer (a small range of fewer than 30 channels is recommended)',
112
+ 1, num_channels, value=(1, 30))
113
+ summit_button = st.form_submit_button('Generate image patches')
114
+ if summit_button:
115
+
116
+ activation = activation[activation_key][:top_k,:,:]
117
+ layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
118
+ # st.write(channel_start, channel_end)
119
+ # st.write(activation.shape, activation.shape[1])
120
+
121
+ if layer_infos != None:
122
+ num_cols, num_rows = top_k, channel_end - channel_start + 1
123
+ # num_rows = activation.shape[1]
124
+ top_k_coor_max_ = activation
125
+ st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
126
+
127
+ for row in range(channel_start, channel_end+1):
128
+ if row == channel_start:
129
+ top_margin = 50
130
+ fig = make_subplots(
131
+ rows=1, cols=num_cols,
132
+ subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
133
+ else:
134
+ top_margin = 0
135
+ fig = make_subplots(rows=1, cols=num_cols)
136
+ for col in range(1, num_cols+1):
137
+ k, c = col-1, row-1
138
+ img_index = int(top_k_coor_max_[k, c, 3])
139
+ activation_value = top_k_coor_max_[k, c, 0]
140
+ img = dataset_dict[img_index//10_000][img_index%10_000]['image']
141
+ class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
142
+ class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
143
+
144
+ idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
145
+ x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
146
+ img = np.array(img)[y1:y2, x1:x2, :]
147
+
148
+ hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
149
+ fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
150
+ fig.update_xaxes(showticklabels=False, showgrid=False)
151
+ fig.update_yaxes(showticklabels=False, showgrid=False)
152
+ fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
153
+ fig.update_layout(showlegend=False, yaxis_title=row)
154
+ fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
155
+ fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
156
+ st.plotly_chart(fig, use_container_width=True)
157
+
158
+
159
+ else:
160
+ col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
161
+ col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0};""", unsafe_allow_html=True)
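Note on the indexing above: each `top_k_coor_max_[k, c, :]` entry is read as `[activation_value, idx_x, idx_y, image_index]`, and the patch is cut out of the source image with the coordinates returned by `get_receptive_field_coordinates`. The sketch below mirrors that flow only; `receptive_field_box`, the stride/window values, and the placeholder image are illustrative assumptions, not the backend implementation.

```python
import numpy as np

# Hypothetical stand-in for the backend helper: map a feature-map location
# (idx_x, idx_y) back to a pixel box, assuming a uniform stride and window size.
def receptive_field_box(stride, window, idx_x, idx_y, img_size=224):
    x1 = max(int(idx_x) * stride, 0)
    y1 = max(int(idx_y) * stride, 0)
    x2 = min(x1 + window, img_size)
    y2 = min(y1 + window, img_size)
    return x1, x2, y1, y2

# Assumed layout of one (k, c) entry: [activation_value, idx_x, idx_y, image_index]
entry = np.array([4.3721, 3, 5, 12345])
activation_value, idx_x, idx_y, img_index = entry[0], entry[1], entry[2], int(entry[3])

x1, x2, y1, y2 = receptive_field_box(stride=16, window=64, idx_x=idx_x, idx_y=idx_y)
img = np.zeros((224, 224, 3), dtype=np.uint8)   # placeholder for dataset_dict[...]['image']
patch = img[y1:y2, x1:x2, :]                    # same crop as the page performs
print(patch.shape, activation_value, img_index)
```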
pages/2_SmoothGrad.py CHANGED
@@ -1,3 +1,145 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ec3b4de671bf21ae9c1aed94b70fb93cf400271301858c72a8d07755da9d61aa
- size 6240
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import random
+ from backend.utils import make_grid, load_dataset
+
+ from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
+ import torch
+
+ from matplotlib.backends.backend_agg import RendererAgg
+ _lock = RendererAgg.lock
+
+ st.set_page_config(layout='wide')
+ BACKGROUND_COLOR = '#bcd0e7'
+
+
+ st.title('Feature attribution with SmoothGrad')
+ st.write('Which features are responsible for the current prediction?')
+
+ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
+
+ # --------------------------- LOAD functions -----------------------------
+
+ @st.cache(allow_output_mutation=True)
+ def load_images(image_ids):
+     images = []
+     for image_id in image_ids:
+         dataset = load_dataset(image_id//10000)
+         images.append(dataset[image_id%10000])
+     return images
+
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
+ def load_model(model_name):
+     with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
+         if model_name == 'ResNet':
+             model_file_path = 'microsoft/resnet-50'
+             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
+             model = AutoModelForImageClassification.from_pretrained(model_file_path)
+             model.eval()
+         elif model_name == 'ConvNeXt':
+             model_file_path = 'facebook/convnext-tiny-224'
+             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
+             model = AutoModelForImageClassification.from_pretrained(model_file_path)
+             model.eval()
+         else:
+             model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
+             model.eval()
+             feature_extractor = None
+     return model, feature_extractor
+
+ images = []
+ image_ids = []
+ # INPUT ------------------------------
+ st.header('Input')
+ with st.form('smooth_grad_form'):
+     st.markdown('**Model and Input Setting**')
+     selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
+     # selected_image_set = st.selectbox('Image set', ['Random set', 'User-defined set'])
+     selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
+
+     # if selected_image_set == 'Class set':
+     #     class_labels = imagenet_df.ClassLabel.unique().tolist()
+     #     class_labels.sort()
+     #     selected_classes = st.multiselect('Class filter', options=['All'] + class_labels)
+     #     if not ('All' in selected_classes or len(selected_classes) == 0):
+     #         imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
+     #     no_images = st.slider('Number of images', 1, len(imagenet_df), value=10)
+     #     image_ids = random.sample(imagenet_df.index.tolist(), k=no_images)
+
+     # user_defined_button = st.form_submit_button('User-defined set')
+     # random_set_button = st.form_submit_button('Random set')
+     # if user_defined_button:
+     #     text = st.text_area('Specific Image IDs', value='0')
+     #     image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
+     # if random_set_button:
+     #     no_images = st.slider('Number of images', 1, 50, value=10)
+     #     image_ids = random.sample(list(range(50_000)), k=no_images)
+
+     submit_button = st.form_submit_button('Set')
+     if submit_button:
+         setting_container = st.container()
+         # for id in image_ids:
+         #     images = load_images(image_ids)
+
+ with st.form('2nd_form'):
+     st.markdown('**Image set setting**')
+     if selected_image_set == 'Random set':
+         no_images = st.slider('Number of images', 1, 50, value=10)
+         image_ids = random.sample(list(range(50_000)), k=no_images)
+     else:
+         text = st.text_area('Specific Image IDs', value='0')
+         image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
+
+     run_button = st.form_submit_button('Display output')
+     if run_button:
+         images = load_images(image_ids)
+
+ st.header('Output')
+
+ models = {}
+ feature_extractors = {}
+
+ for i, model_name in enumerate(selected_models):
+     models[model_name], feature_extractors[model_name] = load_model(model_name)
+
+ # DISPLAY ----------------------------------
+ header_cols = st.columns([1, 1] + [2]*len(selected_models))
+ header_cols[0].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Image ID</b></div>', unsafe_allow_html=True)
+ header_cols[1].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Original Image</b></div>', unsafe_allow_html=True)
+ for i, model_name in enumerate(selected_models):
+     header_cols[i + 2].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>{model_name}</b></div>', unsafe_allow_html=True)
+
+ # one row per image: [Image ID, original image] + [heatmap, masked image] per model
+ grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
+ # grids[0][0].write('Image ID')
+ # grids[0][1].write('Original image')
+
+ # for i, model_name in enumerate(selected_models):
+ #     models[model_name], feature_extractors[model_name] = load_model(model_name)
+
+ @st.cache(allow_output_mutation=True)
+ def generate_images(image, model_name):
+     return generate_smoothgrad_mask(
+         image, model_name,
+         models[model_name], feature_extractors[model_name], num_samples=10)
+
+ with _lock:
+     for j, (image_id, image_dict) in enumerate(zip(image_ids, images)):
+         grids[j][0].write(f'{image_id}. {image_dict["label"]}')
+         image = image_dict['image']
+         ori_image = ShowImage(np.asarray(image))
+         grids[j][1].image(ori_image)
+
+         for i, model_name in enumerate(selected_models):
+             # ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
+             #     model_name, models[model_name], feature_extractors[model_name], num_samples=10)
+             heatmap_image, masked_image = generate_images(image, model_name)
+             # grids[j][1].image(ori_image)
+             grids[j][i*2+2].image(heatmap_image)
+             grids[j][i*2+3].image(masked_image)
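The page above delegates the actual attribution to `backend.smooth_grad.generate_smoothgrad_mask`, whose internals are not part of this diff. For orientation, here is a minimal, generic sketch of the SmoothGrad idea (average the input gradient over several noisy copies of the image). The function name, noise level, and the toy model are assumptions; the real backend additionally works with the Hugging Face feature extractors loaded above rather than raw tensors.

```python
import torch

def smoothgrad_sketch(model, x, target_class, num_samples=10, noise_std=0.15):
    """Minimal SmoothGrad: average input gradients over noisy copies of x.

    x is a (1, C, H, W) float tensor; returns an (H, W) saliency map.
    This is a generic sketch of the technique, not the backend implementation used above.
    """
    grads = torch.zeros_like(x)
    for _ in range(num_samples):
        noisy = (x + noise_std * torch.randn_like(x)).requires_grad_(True)
        score = model(noisy)[0, target_class]   # logit of the target class
        score.backward()
        grads += noisy.grad
    # average over samples, then collapse the channel dimension
    saliency = (grads / num_samples).abs().max(dim=1).values[0]
    return saliency

# Toy usage with a stand-in model so the sketch runs on its own.
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
x = torch.rand(1, 3, 8, 8)
print(smoothgrad_sketch(toy_model, x, target_class=3).shape)  # torch.Size([8, 8])
```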
pages/3_ImageNet1k.py CHANGED
@@ -1,3 +1,48 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:37e96fb5add3af83c7874ce57866a9254ac934813181649811b34a9925993bf2
- size 1736
+ import streamlit as st
+ import pandas as pd
+
+ from backend.utils import load_dataset, use_container_width_percentage
+
+ st.title('ImageNet-1k')
+ st.markdown('This page summarizes the 50,000 images in the validation set of [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)')
+
+ # SCREEN_WIDTH, SCREEN_HEIGHT = 2560, 1664
+
+ with st.spinner("Loading dataset..."):
+     dataset_dict = {}
+     for data_index in range(5):
+         dataset_dict[data_index] = load_dataset(data_index)
+
+ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
+
+ class_labels = imagenet_df.ClassLabel.unique().tolist()
+ class_labels.sort()
+ selected_classes = st.multiselect('Class filter: ', options=['All'] + class_labels)
+ if not ('All' in selected_classes or len(selected_classes) == 0):
+     imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
+ # st.write(class_labels)
+
+ col1, col2 = st.columns([2, 1])
+ with col1:
+     st.dataframe(imagenet_df)
+     use_container_width_percentage(100)
+
+ with col2:
+     st.text_area('Type anything here to copy later :)')
+     image = None
+     with st.form("display image"):
+         img_index = st.text_input('Image ID to display')
+         try:
+             img_index = int(img_index)
+         except ValueError:
+             pass
+
+         submitted = st.form_submit_button('Display this image')
+         if submitted and isinstance(img_index, int):
+             image = dataset_dict[img_index//10_000][img_index%10_000]['image']
+             class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
+             class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
+     if image is not None:
+         st.image(image)
+         st.write('**Class label:** ', class_label)
+         st.write('\n**Class id:** ', str(class_id))
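The lookup `dataset_dict[img_index//10_000][img_index%10_000]` used above assumes the 50,000 validation images are stored as 5 shards of 10,000 each, matching the `range(5)` loading loop. A tiny helper makes the convention explicit; `locate_image` is a hypothetical name used only for illustration:

```python
# Hypothetical helper mirroring the lookup used above: a global image ID
# maps to (shard, offset) under the assumed 5 x 10,000 split.
def locate_image(img_index, shard_size=10_000):
    return img_index // shard_size, img_index % shard_size

print(locate_image(0))       # (0, 0)    -> first image of shard 0
print(locate_image(12345))   # (1, 2345) -> shard 1, offset 2345
print(locate_image(49999))   # (4, 9999) -> last image of shard 4
```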
requirements.txt CHANGED
@@ -1,3 +1,17 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:11db9a919c5ae02c8466823edd349608bbd3c490ca514cfa24b16434e7aea07b
- size 285
+ captum==0.5.0
+ deta==1.1.0
+ graphviz==0.20.1
+ Markdown==3.4.1
+ matplotlib==3.6.2
+ numpy==1.22.3
+ opencv_python_headless==4.6.0.66
+ pandas==1.5.2
+ Pillow==9.3.0
+ plotly==5.11.0
+ scipy==1.9.3
+ setuptools==65.5.0
+ streamlit==1.15.2
+ torch==1.10.1
+ torchvision==0.11.2
+ tqdm==4.64.1
+ transformers==4.25.1
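Since every dependency above is pinned to an exact version, a start-up check can catch accidental drift in the deployed environment. This is an optional sketch, not part of the repository, and the subset of packages checked is arbitrary:

```python
# Optional sanity check: warn if installed versions differ from the pins above.
# Package names are assumed to match their importlib metadata names.
from importlib.metadata import version, PackageNotFoundError

PINS = {
    "streamlit": "1.15.2",
    "torch": "1.10.1",
    "transformers": "4.25.1",
    "plotly": "5.11.0",
}

for package, expected in PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package} is not installed (expected {expected})")
        continue
    if installed != expected:
        print(f"{package} {installed} installed, but {expected} is pinned")
```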