Spaces:
Build error
Build error
taquynhnga
commited on
Commit
•
18f2f54
1
Parent(s):
5d3c300
lfs track pickle & json & csv
Browse files- .gitattributes +1 -8
- .github/workflows/sync_to_huggingface_hub.yml +19 -3
- .gitignore +183 -3
- Home.py +13 -3
- README.md +18 -3
- backend/load_file.py +37 -3
- backend/maximally_activating_patches.py +43 -3
- backend/smooth_grad.py +233 -3
- backend/utils.py +334 -3
- data/ImageNet_metadata.csv +0 -0
- data/dot_architectures/convnext_architecture.dot +194 -3
- data/preprocessed_image_net/val_data_0.pkl +3 -0
- data/preprocessed_image_net/val_data_1.pkl +3 -0
- data/preprocessed_image_net/val_data_2.pkl +3 -0
- data/preprocessed_image_net/val_data_3.pkl +3 -0
- data/preprocessed_image_net/val_data_4.pkl +3 -0
- frontend/__init__.py +6 -3
- frontend/index.html +204 -3
- load_file.py +37 -3
- pages/1_Maximally_activating_patches.py +161 -3
- pages/2_SmoothGrad.py +145 -3
- pages/3_ImageNet1k.py +48 -3
- requirements.txt +17 -3
.gitattributes
CHANGED
@@ -1,10 +1,3 @@
|
|
1 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
2 |
*.json filter=lfs diff=lfs merge=lfs -text
|
3 |
-
|
4 |
-
*.yml filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.gitignore filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.py filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.csv filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.dot filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.html filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.txt filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
2 |
*.json filter=lfs diff=lfs merge=lfs -text
|
3 |
+
.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.github/workflows/sync_to_huggingface_hub.yml
CHANGED
@@ -1,3 +1,19 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face hub
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [main]
|
5 |
+
|
6 |
+
# to run this workflow manually from the Actions tab
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
sync-to-hub:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
fetch-depth: 0
|
16 |
+
- name: Push to hub
|
17 |
+
env:
|
18 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
19 |
+
run: git push --force https://taquynhnga:[email protected]/spaces/taquynhnga/CNNs-interpretation-visualization main
|
.gitignore
CHANGED
@@ -1,3 +1,183 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VSCode
|
2 |
+
.vscode/*
|
3 |
+
!.vscode/settings.json
|
4 |
+
!.vscode/tasks.json
|
5 |
+
!.vscode/launch.json
|
6 |
+
!.vscode/extensions.json
|
7 |
+
*.code-workspace
|
8 |
+
# Local History for Visual Studio Code
|
9 |
+
.history/
|
10 |
+
|
11 |
+
# Common credential files
|
12 |
+
**/credentials.json
|
13 |
+
**/client_secrets.json
|
14 |
+
**/client_secret.json
|
15 |
+
*creds*
|
16 |
+
*.dat
|
17 |
+
*password*
|
18 |
+
*.httr-oauth*
|
19 |
+
|
20 |
+
# Private Node Modules
|
21 |
+
node_modules/
|
22 |
+
creds.js
|
23 |
+
|
24 |
+
# Private Files
|
25 |
+
# *.json
|
26 |
+
# *.csv
|
27 |
+
# *.csv.gz
|
28 |
+
# *.tsv
|
29 |
+
# *.tsv.gz
|
30 |
+
# *.xlsx
|
31 |
+
git-large-file
|
32 |
+
deta_drive.py
|
33 |
+
secret_keys.py
|
34 |
+
|
35 |
+
# Large files
|
36 |
+
# data/preprocessed_image_net/
|
37 |
+
# data/activation/*.pkl
|
38 |
+
# data/activation/*.json
|
39 |
+
# data/layer_infos/*.pkl
|
40 |
+
# data/layer_infos/*.json
|
41 |
+
|
42 |
+
# Mac/OSX
|
43 |
+
.DS_Store
|
44 |
+
|
45 |
+
|
46 |
+
# Byte-compiled / optimized / DLL files
|
47 |
+
__pycache__/
|
48 |
+
*.py[cod]
|
49 |
+
*$py.class
|
50 |
+
|
51 |
+
# C extensions
|
52 |
+
*.so
|
53 |
+
|
54 |
+
# Distribution / packaging
|
55 |
+
.Python
|
56 |
+
build/
|
57 |
+
develop-eggs/
|
58 |
+
dist/
|
59 |
+
downloads/
|
60 |
+
eggs/
|
61 |
+
.eggs/
|
62 |
+
lib/
|
63 |
+
lib64/
|
64 |
+
parts/
|
65 |
+
sdist/
|
66 |
+
var/
|
67 |
+
wheels/
|
68 |
+
share/python-wheels/
|
69 |
+
*.egg-info/
|
70 |
+
.installed.cfg
|
71 |
+
*.egg
|
72 |
+
MANIFEST
|
73 |
+
|
74 |
+
# PyInstaller
|
75 |
+
# Usually these files are written by a python script from a template
|
76 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
77 |
+
*.manifest
|
78 |
+
*.spec
|
79 |
+
|
80 |
+
# Installer logs
|
81 |
+
pip-log.txt
|
82 |
+
pip-delete-this-directory.txt
|
83 |
+
|
84 |
+
# Unit test / coverage reports
|
85 |
+
htmlcov/
|
86 |
+
.tox/
|
87 |
+
.nox/
|
88 |
+
.coverage
|
89 |
+
.coverage.*
|
90 |
+
.cache
|
91 |
+
nosetests.xml
|
92 |
+
coverage.xml
|
93 |
+
*.cover
|
94 |
+
*.py,cover
|
95 |
+
.hypothesis/
|
96 |
+
.pytest_cache/
|
97 |
+
cover/
|
98 |
+
|
99 |
+
# Translations
|
100 |
+
*.mo
|
101 |
+
*.pot
|
102 |
+
|
103 |
+
# Django stuff:
|
104 |
+
*.log
|
105 |
+
local_settings.py
|
106 |
+
db.sqlite3
|
107 |
+
db.sqlite3-journal
|
108 |
+
|
109 |
+
# Flask stuff:
|
110 |
+
instance/
|
111 |
+
.webassets-cache
|
112 |
+
|
113 |
+
# Scrapy stuff:
|
114 |
+
.scrapy
|
115 |
+
|
116 |
+
# Sphinx documentation
|
117 |
+
docs/_build/
|
118 |
+
|
119 |
+
# PyBuilder
|
120 |
+
.pybuilder/
|
121 |
+
target/
|
122 |
+
|
123 |
+
# Jupyter Notebook
|
124 |
+
.ipynb_checkpoints
|
125 |
+
|
126 |
+
# IPython
|
127 |
+
profile_default/
|
128 |
+
ipython_config.py
|
129 |
+
|
130 |
+
# pyenv
|
131 |
+
# For a library or package, you might want to ignore these files since the code is
|
132 |
+
# intended to run in multiple environments; otherwise, check them in:
|
133 |
+
# .python-version
|
134 |
+
|
135 |
+
# pipenv
|
136 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
137 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
138 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
139 |
+
# install all needed dependencies.
|
140 |
+
#Pipfile.lock
|
141 |
+
|
142 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
143 |
+
__pypackages__/
|
144 |
+
|
145 |
+
# Celery stuff
|
146 |
+
celerybeat-schedule
|
147 |
+
celerybeat.pid
|
148 |
+
|
149 |
+
# SageMath parsed files
|
150 |
+
*.sage.py
|
151 |
+
|
152 |
+
# Environments
|
153 |
+
.env
|
154 |
+
.venv
|
155 |
+
env/
|
156 |
+
venv/
|
157 |
+
ENV/
|
158 |
+
env.bak/
|
159 |
+
venv.bak/
|
160 |
+
|
161 |
+
# Spyder project settings
|
162 |
+
.spyderproject
|
163 |
+
.spyproject
|
164 |
+
|
165 |
+
# Rope project settings
|
166 |
+
.ropeproject
|
167 |
+
|
168 |
+
# mkdocs documentation
|
169 |
+
/site
|
170 |
+
|
171 |
+
# mypy
|
172 |
+
.mypy_cache/
|
173 |
+
.dmypy.json
|
174 |
+
dmypy.json
|
175 |
+
|
176 |
+
# Pyre type checker
|
177 |
+
.pyre/
|
178 |
+
|
179 |
+
# pytype static type analyzer
|
180 |
+
.pytype/
|
181 |
+
|
182 |
+
# Cython debug symbols
|
183 |
+
cython_debug/
|
Home.py
CHANGED
@@ -1,3 +1,13 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
6 |
+
import torch
|
7 |
+
|
8 |
+
st.set_page_config(layout='wide')
|
9 |
+
st.title('About')
|
10 |
+
|
11 |
+
st.write('Loaded 3 models')
|
12 |
+
|
13 |
+
|
README.md
CHANGED
@@ -1,3 +1,18 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: CNNs Interpretation Visualization
|
3 |
+
emoji: 🕯
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.10.0
|
8 |
+
app_file: Home.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
# Visualizing Interpretations of CNN models: ConvNeXt, ResNet and MobileNet
|
13 |
+
|
14 |
+
To be change name: CNNs-interpretation-visualization
|
15 |
+
|
16 |
+
This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
|
17 |
+
|
18 |
+
This repo lacks one more folder `data/preprocessed_image_net` which contains 50,000 preprocessed imagenet validation images saved in 5 pickle files.
|
backend/load_file.py
CHANGED
@@ -1,3 +1,37 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import pickle
|
3 |
+
import numpy as np
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
def load_pickle(filename):
|
7 |
+
with open(filename, 'rb') as file:
|
8 |
+
data = pickle.load(file)
|
9 |
+
return data
|
10 |
+
|
11 |
+
def save_pickle_to_json(filename):
|
12 |
+
ordered_dict = load_pickle(filename)
|
13 |
+
json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
|
14 |
+
with open(filename.replace('.pkl', '.json'), 'w') as f:
|
15 |
+
f.write(json_obj)
|
16 |
+
|
17 |
+
def load_json(filename):
|
18 |
+
with open(filename, 'r') as read_file:
|
19 |
+
loaded_dict = json.loads(read_file.read())
|
20 |
+
loaded_dict = OrderedDict(loaded_dict)
|
21 |
+
for k, v in loaded_dict.items():
|
22 |
+
loaded_dict[k] = np.asarray(v)
|
23 |
+
return loaded_dict
|
24 |
+
|
25 |
+
class NumpyEncoder(json.JSONEncoder):
|
26 |
+
def default(self, obj):
|
27 |
+
if isinstance(obj, np.ndarray):
|
28 |
+
return obj.tolist()
|
29 |
+
return json.JSONEncoder.default(self, obj)
|
30 |
+
|
31 |
+
# save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
|
32 |
+
# save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
|
33 |
+
# save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
|
34 |
+
|
35 |
+
file = load_json('data/layer_infos/convnext_layer_infos.json')
|
36 |
+
print(type(file))
|
37 |
+
print(type(file['embeddings.patch_embeddings']))
|
backend/maximally_activating_patches.py
CHANGED
@@ -1,3 +1,43 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from backend.load_file import load_json
|
5 |
+
|
6 |
+
|
7 |
+
@st.cache(allow_output_mutation=True)
|
8 |
+
def load_activation(filename):
|
9 |
+
activation = load_json(filename)
|
10 |
+
return activation
|
11 |
+
|
12 |
+
@st.cache(allow_output_mutation=True)
|
13 |
+
def load_dataset(data_index):
|
14 |
+
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
15 |
+
dataset = pickle.load(file)
|
16 |
+
return dataset
|
17 |
+
|
18 |
+
def load_layer_infos(filename):
|
19 |
+
layer_infos = load_json(filename)
|
20 |
+
return layer_infos
|
21 |
+
|
22 |
+
def get_receptive_field_coordinates(layer_infos, layer_name, idx_x, idx_y):
|
23 |
+
"""
|
24 |
+
layer_name: as in layer_infos keys (eg: 'encoder.stages[0].layers[0]')
|
25 |
+
idx_x: integer coordinate of width axis in feature maps. must < n
|
26 |
+
idx_y: integer coordinate of height axis in feature maps. must < n
|
27 |
+
"""
|
28 |
+
layer_name = layer_name.replace('.dwconv', '').replace('.layernorm', '')
|
29 |
+
layer_name = layer_name.replace('.pwconv1', '').replace('.pwconv2', '').replace('.drop_path', '')
|
30 |
+
n = layer_infos[layer_name]['n']
|
31 |
+
j = layer_infos[layer_name]['j']
|
32 |
+
r = layer_infos[layer_name]['r']
|
33 |
+
start = layer_infos[layer_name]['start']
|
34 |
+
assert idx_x < n, f'n={n}'
|
35 |
+
assert idx_y < n, f'n={n}'
|
36 |
+
|
37 |
+
# image tensor (N, H, W, C) or (N, C, H, W) => image_patch=image[y1:y2, x1:x2]
|
38 |
+
center = (start + idx_x*j, start + idx_y*j)
|
39 |
+
x1, x2 = (max(center[0]-r/2, 0), max(center[0]+r/2, 0))
|
40 |
+
y1, y2 = (max(center[1]-r/2, 0), max(center[1]+r/2, 0))
|
41 |
+
x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2)
|
42 |
+
|
43 |
+
return x1, x2, y1, y2
|
backend/smooth_grad.py
CHANGED
@@ -1,3 +1,233 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import PIL
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from matplotlib import pylab as P
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import TensorDataset
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
|
14 |
+
sys.path.append(dirpath_to_modules)
|
15 |
+
|
16 |
+
from torchvex.base import ExplanationMethod
|
17 |
+
from torchvex.utils.normalization import clamp_quantile
|
18 |
+
|
19 |
+
def ShowImage(im, title='', ax=None):
|
20 |
+
image = np.array(im)
|
21 |
+
return image
|
22 |
+
|
23 |
+
def ShowGrayscaleImage(im, title='', ax=None):
|
24 |
+
if ax is None:
|
25 |
+
P.figure()
|
26 |
+
P.axis('off')
|
27 |
+
P.imshow(im, cmap=P.cm.gray, vmin=0, vmax=1)
|
28 |
+
P.title(title)
|
29 |
+
return P
|
30 |
+
|
31 |
+
def ShowHeatMap(im, title='', ax=None):
|
32 |
+
im = im - im.min()
|
33 |
+
im = im / im.max()
|
34 |
+
im = im.clip(0,1)
|
35 |
+
im = np.uint8(im * 255)
|
36 |
+
|
37 |
+
im = cv2.resize(im, (224,224))
|
38 |
+
image = cv2.resize(im, (224, 224))
|
39 |
+
|
40 |
+
# Apply JET colormap
|
41 |
+
color_heatmap = cv2.applyColorMap(image, cv2.COLORMAP_HOT)
|
42 |
+
# P.imshow(im, cmap='inferno')
|
43 |
+
# P.title(title)
|
44 |
+
return color_heatmap
|
45 |
+
|
46 |
+
def ShowMaskedImage(saliency_map, image, title='', ax=None):
|
47 |
+
"""
|
48 |
+
Save saliency map on image.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
image: Tensor of size (H,W,3)
|
52 |
+
saliency_map: Tensor of size (H,W,1)
|
53 |
+
"""
|
54 |
+
|
55 |
+
# if ax is None:
|
56 |
+
# P.figure()
|
57 |
+
# P.axis('off')
|
58 |
+
|
59 |
+
saliency_map = saliency_map - saliency_map.min()
|
60 |
+
saliency_map = saliency_map / saliency_map.max()
|
61 |
+
saliency_map = saliency_map.clip(0,1)
|
62 |
+
saliency_map = np.uint8(saliency_map * 255)
|
63 |
+
|
64 |
+
saliency_map = cv2.resize(saliency_map, (224,224))
|
65 |
+
image = cv2.resize(image, (224, 224))
|
66 |
+
|
67 |
+
# Apply JET colormap
|
68 |
+
color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_HOT)
|
69 |
+
|
70 |
+
# Blend image with heatmap
|
71 |
+
img_with_heatmap = cv2.addWeighted(image, 0.4, color_heatmap, 0.6, 0)
|
72 |
+
|
73 |
+
# P.imshow(img_with_heatmap)
|
74 |
+
# P.title(title)
|
75 |
+
return img_with_heatmap
|
76 |
+
|
77 |
+
def LoadImage(file_path):
|
78 |
+
im = PIL.Image.open(file_path)
|
79 |
+
im = im.resize((224, 224))
|
80 |
+
im = np.asarray(im)
|
81 |
+
return im
|
82 |
+
|
83 |
+
|
84 |
+
def visualize_image_grayscale(image_3d, percentile=99):
|
85 |
+
r"""Returns a 3D tensor as a grayscale 2D tensor.
|
86 |
+
This method sums a 3D tensor across the absolute value of axis=2, and then
|
87 |
+
clips values at a given percentile.
|
88 |
+
"""
|
89 |
+
image_2d = np.sum(np.abs(image_3d), axis=2)
|
90 |
+
|
91 |
+
vmax = np.percentile(image_2d, percentile)
|
92 |
+
vmin = np.min(image_2d)
|
93 |
+
|
94 |
+
return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1)
|
95 |
+
|
96 |
+
def visualize_image_diverging(image_3d, percentile=99):
|
97 |
+
r"""Returns a 3D tensor as a 2D tensor with positive and negative values.
|
98 |
+
"""
|
99 |
+
image_2d = np.sum(image_3d, axis=2)
|
100 |
+
|
101 |
+
span = abs(np.percentile(image_2d, percentile))
|
102 |
+
vmin = -span
|
103 |
+
vmax = span
|
104 |
+
|
105 |
+
return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1)
|
106 |
+
|
107 |
+
|
108 |
+
class SimpleGradient(ExplanationMethod):
|
109 |
+
def __init__(self, model, create_graph=False,
|
110 |
+
preprocess=None, postprocess=None):
|
111 |
+
super().__init__(model, preprocess, postprocess)
|
112 |
+
self.create_graph = create_graph
|
113 |
+
|
114 |
+
def predict(self, x):
|
115 |
+
return self.model(x)
|
116 |
+
|
117 |
+
@torch.enable_grad()
|
118 |
+
def process(self, inputs, target):
|
119 |
+
self.model.zero_grad()
|
120 |
+
inputs.requires_grad_(True)
|
121 |
+
|
122 |
+
out = self.model(inputs)
|
123 |
+
out = out if type(out) == torch.Tensor else out.logits
|
124 |
+
|
125 |
+
num_classes = out.size(-1)
|
126 |
+
onehot = torch.zeros(inputs.size(0), num_classes, *target.shape[1:])
|
127 |
+
onehot = onehot.to(dtype=inputs.dtype, device=inputs.device)
|
128 |
+
onehot.scatter_(1, target.unsqueeze(1), 1)
|
129 |
+
|
130 |
+
grad, = torch.autograd.grad(
|
131 |
+
(out*onehot).sum(), inputs, create_graph=self.create_graph
|
132 |
+
)
|
133 |
+
|
134 |
+
return grad
|
135 |
+
|
136 |
+
|
137 |
+
class SmoothGradient(ExplanationMethod):
|
138 |
+
def __init__(self, model, stdev_spread=0.15, num_samples=25,
|
139 |
+
magnitude=True, batch_size=-1,
|
140 |
+
create_graph=False, preprocess=None, postprocess=None):
|
141 |
+
super().__init__(model, preprocess, postprocess)
|
142 |
+
self.stdev_spread = stdev_spread
|
143 |
+
self.nsample = num_samples
|
144 |
+
self.create_graph = create_graph
|
145 |
+
self.magnitude = magnitude
|
146 |
+
self.batch_size = batch_size
|
147 |
+
if self.batch_size == -1:
|
148 |
+
self.batch_size = self.nsample
|
149 |
+
|
150 |
+
self._simgrad = SimpleGradient(model, create_graph)
|
151 |
+
|
152 |
+
def process(self, inputs, target):
|
153 |
+
self.model.zero_grad()
|
154 |
+
|
155 |
+
maxima = inputs.flatten(1).max(-1)[0]
|
156 |
+
minima = inputs.flatten(1).min(-1)[0]
|
157 |
+
|
158 |
+
stdev = self.stdev_spread * (maxima - minima).cpu()
|
159 |
+
stdev = stdev.view(inputs.size(0), 1, 1, 1).expand_as(inputs)
|
160 |
+
stdev = stdev.unsqueeze(0).expand(self.nsample, *[-1]*4)
|
161 |
+
noise = torch.normal(0, stdev)
|
162 |
+
|
163 |
+
target_expanded = target.unsqueeze(0).cpu()
|
164 |
+
target_expanded = target_expanded.expand(noise.size(0), -1)
|
165 |
+
|
166 |
+
noiseloader = torch.utils.data.DataLoader(
|
167 |
+
TensorDataset(noise, target_expanded), batch_size=self.batch_size
|
168 |
+
)
|
169 |
+
|
170 |
+
total_gradients = torch.zeros_like(inputs)
|
171 |
+
for noise, t_exp in noiseloader:
|
172 |
+
inputs_w_noise = inputs.unsqueeze(0) + noise.to(inputs.device)
|
173 |
+
inputs_w_noise = inputs_w_noise.view(-1, *inputs.shape[1:])
|
174 |
+
gradients = self._simgrad(inputs_w_noise, t_exp.view(-1))
|
175 |
+
gradients = gradients.view(self.batch_size, *inputs.shape)
|
176 |
+
if self.magnitude:
|
177 |
+
gradients = gradients.pow(2)
|
178 |
+
total_gradients = total_gradients + gradients.sum(0)
|
179 |
+
|
180 |
+
smoothed_gradient = total_gradients / self.nsample
|
181 |
+
return smoothed_gradient
|
182 |
+
|
183 |
+
|
184 |
+
def feed_forward(model_name, image, model=None, feature_extractor=None):
|
185 |
+
if model_name in ['ConvNeXt', 'ResNet']:
|
186 |
+
inputs = feature_extractor(image, return_tensors="pt")
|
187 |
+
logits = model(**inputs).logits
|
188 |
+
prediction_class = logits.argmax(-1).item()
|
189 |
+
else:
|
190 |
+
transform_images = transforms.Compose([
|
191 |
+
transforms.Resize(224),
|
192 |
+
transforms.ToTensor(),
|
193 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
|
194 |
+
input_tensor = transform_images(image)
|
195 |
+
inputs = input_tensor.unsqueeze(0)
|
196 |
+
|
197 |
+
output = model(inputs)
|
198 |
+
prediction_class = output.argmax(-1).item()
|
199 |
+
#prediction_label = model.config.id2label[prediction_class]
|
200 |
+
return inputs, prediction_class
|
201 |
+
|
202 |
+
def clip_gradient(gradient):
|
203 |
+
gradient = gradient.abs().sum(1, keepdim=True)
|
204 |
+
return clamp_quantile(gradient, q=0.99)
|
205 |
+
|
206 |
+
def fig2img(fig):
|
207 |
+
"""Convert a Matplotlib figure to a PIL Image and return it"""
|
208 |
+
import io
|
209 |
+
buf = io.BytesIO()
|
210 |
+
fig.savefig(buf)
|
211 |
+
buf.seek(0)
|
212 |
+
img = Image.open(buf)
|
213 |
+
return img
|
214 |
+
|
215 |
+
def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25):
|
216 |
+
inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
|
217 |
+
|
218 |
+
smoothgrad_gen = SmoothGradient(
|
219 |
+
model, num_samples=num_samples, stdev_spread=0.1,
|
220 |
+
magnitude=False, postprocess=clip_gradient)
|
221 |
+
|
222 |
+
if type(inputs) != torch.Tensor:
|
223 |
+
inputs = inputs['pixel_values']
|
224 |
+
|
225 |
+
smoothgrad_mask = smoothgrad_gen(inputs, prediction_class)
|
226 |
+
smoothgrad_mask = smoothgrad_mask[0].numpy()
|
227 |
+
smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0))
|
228 |
+
|
229 |
+
image = np.asarray(image)
|
230 |
+
# ori_image = ShowImage(image)
|
231 |
+
heat_map_image = ShowHeatMap(smoothgrad_mask)
|
232 |
+
masked_image = ShowMaskedImage(smoothgrad_mask, image)
|
233 |
+
return heat_map_image, masked_image
|
backend/utils.py
CHANGED
@@ -1,3 +1,334 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import io
|
5 |
+
from typing import List, Optional
|
6 |
+
|
7 |
+
import markdown
|
8 |
+
import matplotlib
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import pandas as pd
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
import streamlit as st
|
13 |
+
from plotly import express as px
|
14 |
+
from plotly.subplots import make_subplots
|
15 |
+
from tqdm import trange
|
16 |
+
|
17 |
+
@st.cache(allow_output_mutation=True)
|
18 |
+
def load_dataset(data_index):
|
19 |
+
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
20 |
+
dataset = pickle.load(file)
|
21 |
+
return dataset
|
22 |
+
|
23 |
+
def load_dataset_dict():
|
24 |
+
dataset_dict = {}
|
25 |
+
progress_empty = st.empty()
|
26 |
+
text_empty = st.empty()
|
27 |
+
text_empty.write("Loading datasets...")
|
28 |
+
progress_bar = progress_empty.progress(0.0)
|
29 |
+
for data_index in trange(5):
|
30 |
+
dataset_dict[data_index] = load_dataset(data_index)
|
31 |
+
progress_bar.progress((data_index+1)/5)
|
32 |
+
progress_empty.empty()
|
33 |
+
text_empty.empty()
|
34 |
+
return dataset_dict
|
35 |
+
|
36 |
+
def make_grid(cols=None,rows=None):
|
37 |
+
grid = [0]*rows
|
38 |
+
for i in range(rows):
|
39 |
+
with st.container():
|
40 |
+
grid[i] = st.columns(cols)
|
41 |
+
return grid
|
42 |
+
|
43 |
+
|
44 |
+
def use_container_width_percentage(percentage_width:int = 75):
|
45 |
+
max_width_str = f"max-width: {percentage_width}%;"
|
46 |
+
st.markdown(f"""
|
47 |
+
<style>
|
48 |
+
.reportview-container .main .block-container{{{max_width_str}}}
|
49 |
+
</style>
|
50 |
+
""",
|
51 |
+
unsafe_allow_html=True,
|
52 |
+
)
|
53 |
+
|
54 |
+
matplotlib.use("Agg")
|
55 |
+
COLOR = "#31333f"
|
56 |
+
BACKGROUND_COLOR = "#ffffff"
|
57 |
+
|
58 |
+
|
59 |
+
def grid_demo():
|
60 |
+
"""Main function. Run this to run the app"""
|
61 |
+
st.sidebar.title("Layout and Style Experiments")
|
62 |
+
st.sidebar.header("Settings")
|
63 |
+
st.markdown(
|
64 |
+
"""
|
65 |
+
# Layout and Style Experiments
|
66 |
+
|
67 |
+
The basic question is: Can we create a multi-column dashboard with plots, numbers and text using
|
68 |
+
the [CSS Grid](https://gridbyexample.com/examples)?
|
69 |
+
|
70 |
+
Can we do it with a nice api?
|
71 |
+
Can have a dark theme?
|
72 |
+
"""
|
73 |
+
)
|
74 |
+
|
75 |
+
select_block_container_style()
|
76 |
+
add_resources_section()
|
77 |
+
|
78 |
+
# My preliminary idea of an API for generating a grid
|
79 |
+
with Grid("1 1 1", color=COLOR, background_color=BACKGROUND_COLOR) as grid:
|
80 |
+
grid.cell(
|
81 |
+
class_="a",
|
82 |
+
grid_column_start=2,
|
83 |
+
grid_column_end=3,
|
84 |
+
grid_row_start=1,
|
85 |
+
grid_row_end=2,
|
86 |
+
).markdown("# This is A Markdown Cell")
|
87 |
+
grid.cell("b", 2, 3, 2, 3).text("The cell to the left is a dataframe")
|
88 |
+
grid.cell("c", 3, 4, 2, 3).plotly_chart(get_plotly_fig())
|
89 |
+
grid.cell("d", 1, 2, 1, 3).dataframe(get_dataframe())
|
90 |
+
grid.cell("e", 3, 4, 1, 2).markdown(
|
91 |
+
"Try changing the **block container style** in the sidebar!"
|
92 |
+
)
|
93 |
+
grid.cell("f", 1, 3, 3, 4).text(
|
94 |
+
"The cell to the right is a matplotlib svg image"
|
95 |
+
)
|
96 |
+
grid.cell("g", 3, 4, 3, 4).pyplot(get_matplotlib_plt())
|
97 |
+
|
98 |
+
|
99 |
+
def add_resources_section():
|
100 |
+
"""Adds a resources section to the sidebar"""
|
101 |
+
st.sidebar.header("Add_resources_section")
|
102 |
+
st.sidebar.markdown(
|
103 |
+
"""
|
104 |
+
- [gridbyexample.com] (https://gridbyexample.com/examples/)
|
105 |
+
"""
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
class Cell:
|
110 |
+
"""A Cell can hold text, markdown, plots etc."""
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
class_: str = None,
|
115 |
+
grid_column_start: Optional[int] = None,
|
116 |
+
grid_column_end: Optional[int] = None,
|
117 |
+
grid_row_start: Optional[int] = None,
|
118 |
+
grid_row_end: Optional[int] = None,
|
119 |
+
):
|
120 |
+
self.class_ = class_
|
121 |
+
self.grid_column_start = grid_column_start
|
122 |
+
self.grid_column_end = grid_column_end
|
123 |
+
self.grid_row_start = grid_row_start
|
124 |
+
self.grid_row_end = grid_row_end
|
125 |
+
self.inner_html = ""
|
126 |
+
|
127 |
+
def _to_style(self) -> str:
|
128 |
+
return f"""
|
129 |
+
.{self.class_} {{
|
130 |
+
grid-column-start: {self.grid_column_start};
|
131 |
+
grid-column-end: {self.grid_column_end};
|
132 |
+
grid-row-start: {self.grid_row_start};
|
133 |
+
grid-row-end: {self.grid_row_end};
|
134 |
+
}}
|
135 |
+
"""
|
136 |
+
|
137 |
+
def text(self, text: str = ""):
|
138 |
+
self.inner_html = text
|
139 |
+
|
140 |
+
def markdown(self, text):
|
141 |
+
self.inner_html = markdown.markdown(text)
|
142 |
+
|
143 |
+
def dataframe(self, dataframe: pd.DataFrame):
|
144 |
+
self.inner_html = dataframe.to_html()
|
145 |
+
|
146 |
+
def plotly_chart(self, fig):
|
147 |
+
self.inner_html = f"""
|
148 |
+
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
149 |
+
<body>
|
150 |
+
<p>This should have been a plotly plot.
|
151 |
+
But since *script* tags are removed when inserting MarkDown/ HTML i cannot get it to workto work.
|
152 |
+
But I could potentially save to svg and insert that.</p>
|
153 |
+
<div id='divPlotly'></div>
|
154 |
+
<script>
|
155 |
+
var plotly_data = {fig.to_json()}
|
156 |
+
Plotly.react('divPlotly', plotly_data.data, plotly_data.layout);
|
157 |
+
</script>
|
158 |
+
</body>
|
159 |
+
"""
|
160 |
+
|
161 |
+
def pyplot(self, fig=None, **kwargs):
|
162 |
+
string_io = io.StringIO()
|
163 |
+
plt.savefig(string_io, format="svg", fig=(2, 2))
|
164 |
+
svg = string_io.getvalue()[215:]
|
165 |
+
plt.close(fig)
|
166 |
+
self.inner_html = '<div height="200px">' + svg + "</div>"
|
167 |
+
|
168 |
+
def _to_html(self):
|
169 |
+
return f"""<div class="box {self.class_}">{self.inner_html}</div>"""
|
170 |
+
|
171 |
+
|
172 |
+
class Grid:
|
173 |
+
"""A (CSS) Grid"""
|
174 |
+
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
template_columns="1 1 1",
|
178 |
+
gap="10px",
|
179 |
+
background_color=COLOR,
|
180 |
+
color=BACKGROUND_COLOR,
|
181 |
+
):
|
182 |
+
self.template_columns = template_columns
|
183 |
+
self.gap = gap
|
184 |
+
self.background_color = background_color
|
185 |
+
self.color = color
|
186 |
+
self.cells: List[Cell] = []
|
187 |
+
|
188 |
+
def __enter__(self):
|
189 |
+
return self
|
190 |
+
|
191 |
+
def __exit__(self, type, value, traceback):
|
192 |
+
st.markdown(self._get_grid_style(), unsafe_allow_html=True)
|
193 |
+
st.markdown(self._get_cells_style(), unsafe_allow_html=True)
|
194 |
+
st.markdown(self._get_cells_html(), unsafe_allow_html=True)
|
195 |
+
|
196 |
+
def _get_grid_style(self):
|
197 |
+
return f"""
|
198 |
+
<style>
|
199 |
+
.wrapper {{
|
200 |
+
display: grid;
|
201 |
+
grid-template-columns: {self.template_columns};
|
202 |
+
grid-gap: {self.gap};
|
203 |
+
background-color: {self.color};
|
204 |
+
color: {self.background_color};
|
205 |
+
}}
|
206 |
+
.box {{
|
207 |
+
background-color: {self.color};
|
208 |
+
color: {self.background_color};
|
209 |
+
border-radius: 0px;
|
210 |
+
padding: 0px;
|
211 |
+
font-size: 100%;
|
212 |
+
text-align: center;
|
213 |
+
}}
|
214 |
+
table {{
|
215 |
+
color: {self.color}
|
216 |
+
}}
|
217 |
+
</style>
|
218 |
+
"""
|
219 |
+
|
220 |
+
def _get_cells_style(self):
|
221 |
+
return (
|
222 |
+
"<style>"
|
223 |
+
+ "\n".join([cell._to_style() for cell in self.cells])
|
224 |
+
+ "</style>"
|
225 |
+
)
|
226 |
+
|
227 |
+
def _get_cells_html(self):
|
228 |
+
return (
|
229 |
+
'<div class="wrapper">'
|
230 |
+
+ "\n".join([cell._to_html() for cell in self.cells])
|
231 |
+
+ "</div>"
|
232 |
+
)
|
233 |
+
|
234 |
+
def cell(
|
235 |
+
self,
|
236 |
+
class_: str = None,
|
237 |
+
grid_column_start: Optional[int] = None,
|
238 |
+
grid_column_end: Optional[int] = None,
|
239 |
+
grid_row_start: Optional[int] = None,
|
240 |
+
grid_row_end: Optional[int] = None,
|
241 |
+
):
|
242 |
+
cell = Cell(
|
243 |
+
class_=class_,
|
244 |
+
grid_column_start=grid_column_start,
|
245 |
+
grid_column_end=grid_column_end,
|
246 |
+
grid_row_start=grid_row_start,
|
247 |
+
grid_row_end=grid_row_end,
|
248 |
+
)
|
249 |
+
self.cells.append(cell)
|
250 |
+
return cell
|
251 |
+
|
252 |
+
|
253 |
+
def select_block_container_style():
|
254 |
+
"""Add selection section for setting setting the max-width and padding
|
255 |
+
of the main block container"""
|
256 |
+
st.sidebar.header("Block Container Style")
|
257 |
+
max_width_100_percent = st.sidebar.checkbox("Max-width: 100%?", False)
|
258 |
+
if not max_width_100_percent:
|
259 |
+
max_width = st.sidebar.slider("Select max-width in px", 100, 2000, 1200, 100)
|
260 |
+
else:
|
261 |
+
max_width = 1200
|
262 |
+
dark_theme = st.sidebar.checkbox("Dark Theme?", False)
|
263 |
+
padding_top = st.sidebar.number_input("Select padding top in rem", 0, 200, 5, 1)
|
264 |
+
padding_right = st.sidebar.number_input("Select padding right in rem", 0, 200, 1, 1)
|
265 |
+
padding_left = st.sidebar.number_input("Select padding left in rem", 0, 200, 1, 1)
|
266 |
+
padding_bottom = st.sidebar.number_input(
|
267 |
+
"Select padding bottom in rem", 0, 200, 10, 1
|
268 |
+
)
|
269 |
+
if dark_theme:
|
270 |
+
global COLOR
|
271 |
+
global BACKGROUND_COLOR
|
272 |
+
BACKGROUND_COLOR = "rgb(17,17,17)"
|
273 |
+
COLOR = "#fff"
|
274 |
+
|
275 |
+
_set_block_container_style(
|
276 |
+
max_width,
|
277 |
+
max_width_100_percent,
|
278 |
+
padding_top,
|
279 |
+
padding_right,
|
280 |
+
padding_left,
|
281 |
+
padding_bottom,
|
282 |
+
)
|
283 |
+
|
284 |
+
|
285 |
+
def _set_block_container_style(
|
286 |
+
max_width: int = 1200,
|
287 |
+
max_width_100_percent: bool = False,
|
288 |
+
padding_top: int = 5,
|
289 |
+
padding_right: int = 1,
|
290 |
+
padding_left: int = 1,
|
291 |
+
padding_bottom: int = 10,
|
292 |
+
):
|
293 |
+
if max_width_100_percent:
|
294 |
+
max_width_str = f"max-width: 100%;"
|
295 |
+
else:
|
296 |
+
max_width_str = f"max-width: {max_width}px;"
|
297 |
+
st.markdown(
|
298 |
+
f"""
|
299 |
+
<style>
|
300 |
+
.reportview-container .main .block-container{{
|
301 |
+
{max_width_str}
|
302 |
+
padding-top: {padding_top}rem;
|
303 |
+
padding-right: {padding_right}rem;
|
304 |
+
padding-left: {padding_left}rem;
|
305 |
+
padding-bottom: {padding_bottom}rem;
|
306 |
+
}}
|
307 |
+
.reportview-container .main {{
|
308 |
+
color: {COLOR};
|
309 |
+
background-color: {BACKGROUND_COLOR};
|
310 |
+
}}
|
311 |
+
</style>
|
312 |
+
""",
|
313 |
+
unsafe_allow_html=True,
|
314 |
+
)
|
315 |
+
|
316 |
+
|
317 |
+
@st.cache
|
318 |
+
def get_dataframe() -> pd.DataFrame():
|
319 |
+
"""Dummy DataFrame"""
|
320 |
+
data = [
|
321 |
+
{"quantity": 1, "price": 2},
|
322 |
+
{"quantity": 3, "price": 5},
|
323 |
+
{"quantity": 4, "price": 8},
|
324 |
+
]
|
325 |
+
return pd.DataFrame(data)
|
326 |
+
|
327 |
+
|
328 |
+
def get_plotly_fig():
|
329 |
+
"""Dummy Plotly Plot"""
|
330 |
+
return px.line(data_frame=get_dataframe(), x="quantity", y="price")
|
331 |
+
|
332 |
+
|
333 |
+
def get_matplotlib_plt():
|
334 |
+
get_dataframe().plot(kind="line", x="quantity", y="price", figsize=(5, 3))
|
data/ImageNet_metadata.csv
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
data/dot_architectures/convnext_architecture.dot
CHANGED
@@ -1,3 +1,194 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
digraph convnext {
|
2 |
+
graph [labeljust=left ranksep=0.2]
|
3 |
+
node [fixedsize=true shape=box width=3]
|
4 |
+
edge [arrowhead=vee]
|
5 |
+
image [label=image shape=plaintext]
|
6 |
+
embeddings_conv [label="4x4 conv, 96, stride 4"]
|
7 |
+
image -> embeddings_conv
|
8 |
+
subgraph "cluster stage 1" {
|
9 |
+
graph [color="#e9f2f7" label="stage 1\l" style=filled]
|
10 |
+
node [color="#bcd9e7" style=filled]
|
11 |
+
subgraph "cluster stage 1 block 1" {
|
12 |
+
graph [color="#0f4158" label="block 1\l" style=dashed]
|
13 |
+
"stage 1 block 1 dwconv" [label="d7x7 conv, 96, stride 1"]
|
14 |
+
"stage 1 block 1 pwconv1" [label="1x1 conv, 384"]
|
15 |
+
"stage 1 block 1 pwconv2" [label="1x1 conv, 96"]
|
16 |
+
"stage 1 block 1 dwconv" -> "stage 1 block 1 pwconv1"
|
17 |
+
"stage 1 block 1 pwconv1" -> "stage 1 block 1 pwconv2"
|
18 |
+
}
|
19 |
+
subgraph "cluster stage 1 block 2" {
|
20 |
+
graph [color="#0f4158" label="block 2\l" style=dashed]
|
21 |
+
"stage 1 block 2 dwconv" [label="d7x7 conv, 96, stride 1"]
|
22 |
+
"stage 1 block 2 pwconv1" [label="1x1 conv, 384"]
|
23 |
+
"stage 1 block 2 pwconv2" [label="1x1 conv, 96"]
|
24 |
+
"stage 1 block 2 dwconv" -> "stage 1 block 2 pwconv1"
|
25 |
+
"stage 1 block 2 pwconv1" -> "stage 1 block 2 pwconv2"
|
26 |
+
}
|
27 |
+
subgraph "cluster stage 1 block 3" {
|
28 |
+
graph [color="#0f4158" label="block 3\l" style=dashed]
|
29 |
+
"stage 1 block 3 dwconv" [label="d7x7 conv, 96, stride 1"]
|
30 |
+
"stage 1 block 3 pwconv1" [label="1x1 conv, 384"]
|
31 |
+
"stage 1 block 3 pwconv2" [label="1x1 conv, 96"]
|
32 |
+
"stage 1 block 3 dwconv" -> "stage 1 block 3 pwconv1"
|
33 |
+
"stage 1 block 3 pwconv1" -> "stage 1 block 3 pwconv2"
|
34 |
+
}
|
35 |
+
"stage 1 block 1 pwconv2" -> "stage 1 block 2 dwconv"
|
36 |
+
"stage 1 block 2 pwconv2" -> "stage 1 block 3 dwconv"
|
37 |
+
}
|
38 |
+
subgraph "cluster stage 2" {
|
39 |
+
graph [color="#e9f7f3" label="stage 2\l" style=filled]
|
40 |
+
node [color="#bce7db" style=filled]
|
41 |
+
"stage 2 downsampling" [label="2x2 conv, 192, stride 2"]
|
42 |
+
subgraph "cluster stage 2 block 1" {
|
43 |
+
graph [color="#0f5844" label="block 1\l" style=dashed]
|
44 |
+
"stage 2 block 1 dwconv" [label="d7x7 conv, 192, stride 1"]
|
45 |
+
"stage 2 block 1 pwconv1" [label="1x1 conv, 768"]
|
46 |
+
"stage 2 block 1 pwconv2" [label="1x1 conv, 192"]
|
47 |
+
"stage 2 block 1 dwconv" -> "stage 2 block 1 pwconv1"
|
48 |
+
"stage 2 block 1 pwconv1" -> "stage 2 block 1 pwconv2"
|
49 |
+
}
|
50 |
+
subgraph "cluster stage 2 block 2" {
|
51 |
+
graph [color="#0f5844" label="block 2\l" style=dashed]
|
52 |
+
"stage 2 block 2 dwconv" [label="d7x7 conv, 192, stride 1"]
|
53 |
+
"stage 2 block 2 pwconv1" [label="1x1 conv, 768"]
|
54 |
+
"stage 2 block 2 pwconv2" [label="1x1 conv, 192"]
|
55 |
+
"stage 2 block 2 dwconv" -> "stage 2 block 2 pwconv1"
|
56 |
+
"stage 2 block 2 pwconv1" -> "stage 2 block 2 pwconv2"
|
57 |
+
}
|
58 |
+
subgraph "cluster stage 2 block 3" {
|
59 |
+
graph [color="#0f5844" label="block 3\l" style=dashed]
|
60 |
+
"stage 2 block 3 dwconv" [label="d7x7 conv, 192, stride 1"]
|
61 |
+
"stage 2 block 3 pwconv1" [label="1x1 conv, 768"]
|
62 |
+
"stage 2 block 3 pwconv2" [label="1x1 conv, 192"]
|
63 |
+
"stage 2 block 3 dwconv" -> "stage 2 block 3 pwconv1"
|
64 |
+
"stage 2 block 3 pwconv1" -> "stage 2 block 3 pwconv2"
|
65 |
+
}
|
66 |
+
"stage 2 downsampling" -> "stage 2 block 1 dwconv"
|
67 |
+
"stage 2 block 1 pwconv2" -> "stage 2 block 2 dwconv"
|
68 |
+
"stage 2 block 2 pwconv2" -> "stage 2 block 3 dwconv"
|
69 |
+
}
|
70 |
+
subgraph "cluster stage 3" {
|
71 |
+
graph [color="#e9f2f7" label="stage 3\l" style=filled]
|
72 |
+
node [color="#bcd9e7" style=filled]
|
73 |
+
"stage 3 downsampling" [label="2x2 conv, 384, stride 2"]
|
74 |
+
subgraph "cluster stage 3 block 1" {
|
75 |
+
graph [color="#0f4158" label="block 1\l" style=dashed]
|
76 |
+
"stage 3 block 1 dwconv" [label="d7x7 conv, 384, stride 1"]
|
77 |
+
"stage 3 block 1 pwconv1" [label="1x1 conv, 1536"]
|
78 |
+
"stage 3 block 1 pwconv2" [label="1x1 conv, 384"]
|
79 |
+
"stage 3 block 1 dwconv" -> "stage 3 block 1 pwconv1"
|
80 |
+
"stage 3 block 1 pwconv1" -> "stage 3 block 1 pwconv2"
|
81 |
+
}
|
82 |
+
subgraph "cluster stage 3 block 2" {
|
83 |
+
graph [color="#0f4158" label="block 2\l" style=dashed]
|
84 |
+
"stage 3 block 2 dwconv" [label="d7x7 conv, 384, stride 1"]
|
85 |
+
"stage 3 block 2 pwconv1" [label="1x1 conv, 1536"]
|
86 |
+
"stage 3 block 2 pwconv2" [label="1x1 conv, 384"]
|
87 |
+
"stage 3 block 2 dwconv" -> "stage 3 block 2 pwconv1"
|
88 |
+
"stage 3 block 2 pwconv1" -> "stage 3 block 2 pwconv2"
|
89 |
+
}
|
90 |
+
subgraph "cluster stage 3 block 3" {
|
91 |
+
graph [color="#0f4158" label="block 3\l" style=dashed]
|
92 |
+
"stage 3 block 3 dwconv" [label="d7x7 conv, 384, stride 1"]
|
93 |
+
"stage 3 block 3 pwconv1" [label="1x1 conv, 1536"]
|
94 |
+
"stage 3 block 3 pwconv2" [label="1x1 conv, 384"]
|
95 |
+
"stage 3 block 3 dwconv" -> "stage 3 block 3 pwconv1"
|
96 |
+
"stage 3 block 3 pwconv1" -> "stage 3 block 3 pwconv2"
|
97 |
+
}
|
98 |
+
subgraph "cluster stage 3 block 4" {
|
99 |
+
graph [color="#0f4158" label="block 4\l" style=dashed]
|
100 |
+
"stage 3 block 4 dwconv" [label="d7x7 conv, 384, stride 1"]
|
101 |
+
"stage 3 block 4 pwconv1" [label="1x1 conv, 1536"]
|
102 |
+
"stage 3 block 4 pwconv2" [label="1x1 conv, 384"]
|
103 |
+
"stage 3 block 4 dwconv" -> "stage 3 block 4 pwconv1"
|
104 |
+
"stage 3 block 4 pwconv1" -> "stage 3 block 4 pwconv2"
|
105 |
+
}
|
106 |
+
subgraph "cluster stage 3 block 5" {
|
107 |
+
graph [color="#0f4158" label="block 5\l" style=dashed]
|
108 |
+
"stage 3 block 5 dwconv" [label="d7x7 conv, 384, stride 1"]
|
109 |
+
"stage 3 block 5 pwconv1" [label="1x1 conv, 1536"]
|
110 |
+
"stage 3 block 5 pwconv2" [label="1x1 conv, 384"]
|
111 |
+
"stage 3 block 5 dwconv" -> "stage 3 block 5 pwconv1"
|
112 |
+
"stage 3 block 5 pwconv1" -> "stage 3 block 5 pwconv2"
|
113 |
+
}
|
114 |
+
subgraph "cluster stage 3 block 6" {
|
115 |
+
graph [color="#0f4158" label="block 6\l" style=dashed]
|
116 |
+
"stage 3 block 6 dwconv" [label="d7x7 conv, 384, stride 1"]
|
117 |
+
"stage 3 block 6 pwconv1" [label="1x1 conv, 1536"]
|
118 |
+
"stage 3 block 6 pwconv2" [label="1x1 conv, 384"]
|
119 |
+
"stage 3 block 6 dwconv" -> "stage 3 block 6 pwconv1"
|
120 |
+
"stage 3 block 6 pwconv1" -> "stage 3 block 6 pwconv2"
|
121 |
+
}
|
122 |
+
subgraph "cluster stage 3 block 7" {
|
123 |
+
graph [color="#0f4158" label="block 7\l" style=dashed]
|
124 |
+
"stage 3 block 7 dwconv" [label="d7x7 conv, 384, stride 1"]
|
125 |
+
"stage 3 block 7 pwconv1" [label="1x1 conv, 1536"]
|
126 |
+
"stage 3 block 7 pwconv2" [label="1x1 conv, 384"]
|
127 |
+
"stage 3 block 7 dwconv" -> "stage 3 block 7 pwconv1"
|
128 |
+
"stage 3 block 7 pwconv1" -> "stage 3 block 7 pwconv2"
|
129 |
+
}
|
130 |
+
subgraph "cluster stage 3 block 8" {
|
131 |
+
graph [color="#0f4158" label="block 8\l" style=dashed]
|
132 |
+
"stage 3 block 8 dwconv" [label="d7x7 conv, 384, stride 1"]
|
133 |
+
"stage 3 block 8 pwconv1" [label="1x1 conv, 1536"]
|
134 |
+
"stage 3 block 8 pwconv2" [label="1x1 conv, 384"]
|
135 |
+
"stage 3 block 8 dwconv" -> "stage 3 block 8 pwconv1"
|
136 |
+
"stage 3 block 8 pwconv1" -> "stage 3 block 8 pwconv2"
|
137 |
+
}
|
138 |
+
subgraph "cluster stage 3 block 9" {
|
139 |
+
graph [color="#0f4158" label="block 9\l" style=dashed]
|
140 |
+
"stage 3 block 9 dwconv" [label="d7x7 conv, 384, stride 1"]
|
141 |
+
"stage 3 block 9 pwconv1" [label="1x1 conv, 1536"]
|
142 |
+
"stage 3 block 9 pwconv2" [label="1x1 conv, 384"]
|
143 |
+
"stage 3 block 9 dwconv" -> "stage 3 block 9 pwconv1"
|
144 |
+
"stage 3 block 9 pwconv1" -> "stage 3 block 9 pwconv2"
|
145 |
+
}
|
146 |
+
"stage 3 downsampling" -> "stage 3 block 1 dwconv"
|
147 |
+
"stage 3 block 1 pwconv2" -> "stage 3 block 2 dwconv"
|
148 |
+
"stage 3 block 2 pwconv2" -> "stage 3 block 3 dwconv"
|
149 |
+
"stage 3 block 3 pwconv2" -> "stage 3 block 4 dwconv"
|
150 |
+
"stage 3 block 4 pwconv2" -> "stage 3 block 5 dwconv"
|
151 |
+
"stage 3 block 5 pwconv2" -> "stage 3 block 6 dwconv"
|
152 |
+
"stage 3 block 6 pwconv2" -> "stage 3 block 7 dwconv"
|
153 |
+
"stage 3 block 7 pwconv2" -> "stage 3 block 8 dwconv"
|
154 |
+
"stage 3 block 8 pwconv2" -> "stage 3 block 9 dwconv"
|
155 |
+
}
|
156 |
+
subgraph "cluster stage 4" {
|
157 |
+
graph [color="#e9f7f3" label="stage 4\l" style=filled]
|
158 |
+
node [color="#bce7db" style=filled]
|
159 |
+
"stage 4 downsampling" [label="2x2 conv, 768, stride 2"]
|
160 |
+
subgraph "cluster stage 4 block 1" {
|
161 |
+
graph [color="#0f5844" label="block 1\l" style=dashed]
|
162 |
+
"stage 4 block 1 dwconv" [label="d7x7 conv, 768, stride 1"]
|
163 |
+
"stage 4 block 1 pwconv1" [label="1x1 conv, 3072"]
|
164 |
+
"stage 4 block 1 pwconv2" [label="1x1 conv, 768"]
|
165 |
+
"stage 4 block 1 dwconv" -> "stage 4 block 1 pwconv1"
|
166 |
+
"stage 4 block 1 pwconv1" -> "stage 4 block 1 pwconv2"
|
167 |
+
}
|
168 |
+
subgraph "cluster stage 4 block 2" {
|
169 |
+
graph [color="#0f5844" label="block 2\l" style=dashed]
|
170 |
+
"stage 4 block 2 dwconv" [label="d7x7 conv, 768, stride 1"]
|
171 |
+
"stage 4 block 2 pwconv1" [label="1x1 conv, 3072"]
|
172 |
+
"stage 4 block 2 pwconv2" [label="1x1 conv, 768"]
|
173 |
+
"stage 4 block 2 dwconv" -> "stage 4 block 2 pwconv1"
|
174 |
+
"stage 4 block 2 pwconv1" -> "stage 4 block 2 pwconv2"
|
175 |
+
}
|
176 |
+
subgraph "cluster stage 4 block 3" {
|
177 |
+
graph [color="#0f5844" label="block 3\l" style=dashed]
|
178 |
+
"stage 4 block 3 dwconv" [label="d7x7 conv, 768, stride 1"]
|
179 |
+
"stage 4 block 3 pwconv1" [label="1x1 conv, 3072"]
|
180 |
+
"stage 4 block 3 pwconv2" [label="1x1 conv, 768"]
|
181 |
+
"stage 4 block 3 dwconv" -> "stage 4 block 3 pwconv1"
|
182 |
+
"stage 4 block 3 pwconv1" -> "stage 4 block 3 pwconv2"
|
183 |
+
}
|
184 |
+
"stage 4 downsampling" -> "stage 4 block 1 dwconv"
|
185 |
+
"stage 4 block 1 pwconv2" -> "stage 4 block 2 dwconv"
|
186 |
+
"stage 4 block 2 pwconv2" -> "stage 4 block 3 dwconv"
|
187 |
+
}
|
188 |
+
"stage 1 block 3 pwconv2" -> "stage 2 downsampling"
|
189 |
+
"stage 2 block 3 pwconv2" -> "stage 3 downsampling"
|
190 |
+
"stage 3 block 9 pwconv2" -> "stage 4 downsampling"
|
191 |
+
embeddings_conv -> "stage 1 block 1 dwconv"
|
192 |
+
"stage 4 block 3 pwconv2" -> "output vector"
|
193 |
+
"output vector" [label="output vector" shape=plaintext]
|
194 |
+
}
|
data/preprocessed_image_net/val_data_0.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2698bdc240555e2a46a40936df87275bc5852142d30e921ae0dad9289b0f576f
|
3 |
+
size 906108480
|
data/preprocessed_image_net/val_data_1.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21780d77e212695dbee84d6d2ad17a5a520bc1634f68e1c8fd120f069ad76da1
|
3 |
+
size 907109023
|
data/preprocessed_image_net/val_data_2.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2cfc83b78420baa1b2c3a8da92e7fba1f33443d506f483ecff13cdba2035ab3c
|
3 |
+
size 907435149
|
data/preprocessed_image_net/val_data_3.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f5e2c7cb4d6bae17fbd062a0b46f2cee457ad466b725f7bdf0f8426069cafee
|
3 |
+
size 906089333
|
data/preprocessed_image_net/val_data_4.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ed53c87ec8b9945db31f910eb44b7e3092324643de25ea53a99fc29137df854
|
3 |
+
size 905439763
|
frontend/__init__.py
CHANGED
@@ -1,3 +1,6 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
1 |
+
import streamlit.components.v1 as components
|
2 |
+
|
3 |
+
on_click_graph = components.declare_component(
|
4 |
+
"on_click_graph",
|
5 |
+
path="./frontend"
|
6 |
+
)
|
frontend/index.html
CHANGED
@@ -1,3 +1,204 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<html>
|
2 |
+
|
3 |
+
<head>
|
4 |
+
<style type="text/css">
|
5 |
+
</style>
|
6 |
+
</head>
|
7 |
+
|
8 |
+
<!--
|
9 |
+
----------------------------------------------------
|
10 |
+
Your custom static HTML goes in the body:
|
11 |
+
-->
|
12 |
+
|
13 |
+
<body>
|
14 |
+
</body>
|
15 |
+
|
16 |
+
<script type="text/javascript">
|
17 |
+
// Helper function to send type and data messages to Streamlit client
|
18 |
+
|
19 |
+
const SET_COMPONENT_VALUE = "streamlit:setComponentValue"
|
20 |
+
const RENDER = "streamlit:render"
|
21 |
+
const COMPONENT_READY = "streamlit:componentReady"
|
22 |
+
const SET_FRAME_HEIGHT = "streamlit:setFrameHeight"
|
23 |
+
var HIGHTLIGHT_COLOR;
|
24 |
+
var original_colors;
|
25 |
+
|
26 |
+
function _sendMessage(type, data) {
|
27 |
+
// copy data into object
|
28 |
+
var outboundData = Object.assign({
|
29 |
+
isStreamlitMessage: true,
|
30 |
+
type: type,
|
31 |
+
}, data)
|
32 |
+
|
33 |
+
if (type == SET_COMPONENT_VALUE) {
|
34 |
+
console.log("_sendMessage data: ", SET_COMPONENT_VALUE)
|
35 |
+
// console.log("_sendMessage data: " + JSON.stringify(data))
|
36 |
+
// console.log("_sendMessage outboundData: " + JSON.stringify(outboundData))
|
37 |
+
}
|
38 |
+
|
39 |
+
window.parent.postMessage(outboundData, "*")
|
40 |
+
}
|
41 |
+
|
42 |
+
function initialize(pipeline) {
|
43 |
+
|
44 |
+
// Hook Streamlit's message events into a simple dispatcher of pipeline handlers
|
45 |
+
window.addEventListener("message", (event) => {
|
46 |
+
if (event.data.type == RENDER) {
|
47 |
+
// The event.data.args dict holds any JSON-serializable value
|
48 |
+
// sent from the Streamlit client. It is already deserialized.
|
49 |
+
pipeline.forEach(handler => {
|
50 |
+
handler(event.data.args)
|
51 |
+
})
|
52 |
+
}
|
53 |
+
})
|
54 |
+
|
55 |
+
_sendMessage(COMPONENT_READY, { apiVersion: 1 });
|
56 |
+
|
57 |
+
// Component should be mounted by Streamlit in an iframe, so try to autoset the iframe height.
|
58 |
+
window.addEventListener("load", () => {
|
59 |
+
window.setTimeout(function () {
|
60 |
+
setFrameHeight(document.documentElement.clientHeight)
|
61 |
+
}, 0)
|
62 |
+
})
|
63 |
+
|
64 |
+
// Optionally, if auto-height computation fails, you can manually set it
|
65 |
+
// (uncomment below)
|
66 |
+
setFrameHeight(0)
|
67 |
+
}
|
68 |
+
|
69 |
+
function setFrameHeight(height) {
|
70 |
+
_sendMessage(SET_FRAME_HEIGHT, { height: height })
|
71 |
+
}
|
72 |
+
|
73 |
+
// The `data` argument can be any JSON-serializable value.
|
74 |
+
function notifyHost(data) {
|
75 |
+
_sendMessage(SET_COMPONENT_VALUE, data)
|
76 |
+
}
|
77 |
+
|
78 |
+
function changeButtonColor(button, color) {
|
79 |
+
pol = button.querySelectorAll('polygon')[0]
|
80 |
+
pol.setAttribute('fill', color)
|
81 |
+
pol.setAttribute('stroke', color)
|
82 |
+
}
|
83 |
+
|
84 |
+
function getButtonColor(button) {
|
85 |
+
pol = button.querySelectorAll('polygon')[0]
|
86 |
+
return pol.getAttribute('fill')
|
87 |
+
}
|
88 |
+
// ----------------------------------------------------
|
89 |
+
// Your custom functionality for the component goes here:
|
90 |
+
|
91 |
+
function toggle(button) {
|
92 |
+
group = 'node'
|
93 |
+
let button_color;
|
94 |
+
nodes = window.parent.document.getElementsByClassName('node')
|
95 |
+
console.log("nodes.length = ", nodes.length)
|
96 |
+
// for (let i = 0; i < nodes.length; i++) {
|
97 |
+
// console.log(nodes.item(i))
|
98 |
+
// }
|
99 |
+
console.log("selected button ", button, button.getAttribute('class'), button.id)
|
100 |
+
|
101 |
+
for (let i = 0; i < nodes.length; i++) {
|
102 |
+
polygons = nodes.item(i).querySelectorAll('polygon')
|
103 |
+
if (polygons.length == 0) {
|
104 |
+
continue
|
105 |
+
}
|
106 |
+
if (button.id == nodes.item(i).id & button.getAttribute('class').includes("off")) {
|
107 |
+
button.setAttribute('class', group + " on")
|
108 |
+
button_color = original_colors[i]
|
109 |
+
|
110 |
+
} else if (button.id == nodes.item(i).id & button.getAttribute('class').includes("on")) {
|
111 |
+
button.setAttribute('class', group + " off")
|
112 |
+
button_color = original_colors[i]
|
113 |
+
} else if (button.id == nodes.item(i).id) {
|
114 |
+
button.setAttribute('class', group + " on")
|
115 |
+
button_color = original_colors[i]
|
116 |
+
|
117 |
+
} else if (button.id != nodes.item(i).id & nodes.item(i).getAttribute('class').includes("on")) {
|
118 |
+
nodes.item(i).className = group + " off"
|
119 |
+
} else {
|
120 |
+
nodes.item(i).className = group + " off"
|
121 |
+
}
|
122 |
+
}
|
123 |
+
|
124 |
+
nodes = window.parent.document.getElementsByClassName('node')
|
125 |
+
actions = []
|
126 |
+
for (let i = 0; i < nodes.length; i++) {
|
127 |
+
polygons = nodes.item(i).querySelectorAll('polygon')
|
128 |
+
if (polygons.length == 0) {
|
129 |
+
continue
|
130 |
+
}
|
131 |
+
btn = nodes.item(i)
|
132 |
+
ori_color = original_colors[i]
|
133 |
+
color = btn.querySelectorAll('polygon')[0].getAttribute('fill')
|
134 |
+
actions.push({ "action": btn.getAttribute("class").includes("on"), "original_color": ori_color, "color": color})
|
135 |
+
}
|
136 |
+
|
137 |
+
states = {}
|
138 |
+
states['choice'] = {
|
139 |
+
"node_title": button.querySelectorAll("title")[0].innerHTML,
|
140 |
+
"node_id": button.id,
|
141 |
+
"state": {
|
142 |
+
"action": button.getAttribute("class").includes("on"),
|
143 |
+
"original_color": button_color,
|
144 |
+
"color": button.querySelectorAll('polygon')[0].getAttribute('fill')
|
145 |
+
}
|
146 |
+
}
|
147 |
+
states["options"] = {"states": actions }
|
148 |
+
|
149 |
+
notifyHost({
|
150 |
+
value: states,
|
151 |
+
dataType: "json",
|
152 |
+
})
|
153 |
+
}
|
154 |
+
|
155 |
+
// ----------------------------------------------------
|
156 |
+
// Here you can customize a pipeline of handlers for
|
157 |
+
// inbound properties from the Streamlit client app
|
158 |
+
|
159 |
+
// Set initial value sent from Streamlit!
|
160 |
+
function initializeProps_Handler(props) {
|
161 |
+
HIGHTLIGHT_COLOR = props['hightlight_color']
|
162 |
+
original_colors = []
|
163 |
+
// nodes = document.getElementsByClassName('node')
|
164 |
+
nodes = window.parent.document.getElementsByClassName('node')
|
165 |
+
console.log(nodes)
|
166 |
+
for (let i = 0; i < nodes.length; i++) {
|
167 |
+
// color = nodes.item(i).getElementsByTagName('POLYGON')[0].getAttribute("fill")
|
168 |
+
// nodes.item(i).addEventListener("click", toggle)
|
169 |
+
polygons = nodes.item(i).querySelectorAll('polygon')
|
170 |
+
if (polygons.length == 0) {
|
171 |
+
original_colors.push('none')
|
172 |
+
continue
|
173 |
+
}
|
174 |
+
|
175 |
+
color = polygons[0].getAttribute("fill")
|
176 |
+
if (!nodes.item(i).hasAttribute('color')) {
|
177 |
+
nodes.item(i).setAttribute("color", color)
|
178 |
+
original_colors.push(color)
|
179 |
+
} else {
|
180 |
+
original_colors.push(nodes.item(i).getAttribute("color"))
|
181 |
+
}
|
182 |
+
nodes.item(i).addEventListener("click", function (event) {toggle(this)})
|
183 |
+
}
|
184 |
+
// console.log("original colors:", original_colors)
|
185 |
+
}
|
186 |
+
// Access values sent from Streamlit!
|
187 |
+
function dataUpdate_Handler(props) {
|
188 |
+
console.log('dataUpdate_Handler...........')
|
189 |
+
let msgLabel = document.getElementById("message_label")
|
190 |
+
}
|
191 |
+
// Simply log received data dictionary
|
192 |
+
function log_Handler(props) {
|
193 |
+
console.log("Received from Streamlit: " + JSON.stringify(props))
|
194 |
+
}
|
195 |
+
|
196 |
+
let pipeline = [initializeProps_Handler, dataUpdate_Handler, log_Handler]
|
197 |
+
|
198 |
+
// ----------------------------------------------------
|
199 |
+
// Finally, initialize component passing in pipeline
|
200 |
+
initialize(pipeline)
|
201 |
+
|
202 |
+
</script>
|
203 |
+
|
204 |
+
</html>
|
load_file.py
CHANGED
@@ -1,3 +1,37 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import pickle
|
3 |
+
import numpy as np
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
def load_pickle(filename):
|
7 |
+
with open(filename, 'rb') as file:
|
8 |
+
data = pickle.load(file)
|
9 |
+
return data
|
10 |
+
|
11 |
+
def save_pickle_to_json(filename):
|
12 |
+
ordered_dict = load_pickle(filename)
|
13 |
+
json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
|
14 |
+
with open(filename.replace('.pkl', '.json'), 'w') as f:
|
15 |
+
f.write(json_obj)
|
16 |
+
|
17 |
+
def load_json(filename):
|
18 |
+
with open(filename, 'r') as read_file:
|
19 |
+
loaded_dict = json.loads(read_file.read())
|
20 |
+
loaded_dict = OrderedDict(loaded_dict)
|
21 |
+
for k, v in loaded_dict.items():
|
22 |
+
loaded_dict[k] = np.asarray(v)
|
23 |
+
return loaded_dict
|
24 |
+
|
25 |
+
class NumpyEncoder(json.JSONEncoder):
|
26 |
+
def default(self, obj):
|
27 |
+
if isinstance(obj, np.ndarray):
|
28 |
+
return obj.tolist()
|
29 |
+
return json.JSONEncoder.default(self, obj)
|
30 |
+
|
31 |
+
# save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
|
32 |
+
# save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
|
33 |
+
# save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
|
34 |
+
|
35 |
+
file = load_json('data/layer_infos/convnext_layer_infos.json')
|
36 |
+
print(type(file))
|
37 |
+
print(type(file['embeddings.patch_embeddings']))
|
pages/1_Maximally_activating_patches.py
CHANGED
@@ -1,3 +1,161 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from plotly.subplots import make_subplots
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
|
7 |
+
import graphviz
|
8 |
+
|
9 |
+
from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
|
10 |
+
from frontend import on_click_graph
|
11 |
+
from backend.utils import load_dataset_dict
|
12 |
+
|
13 |
+
HIGHTLIGHT_COLOR = '#e7bcc5'
|
14 |
+
st.set_page_config(layout='wide')
|
15 |
+
|
16 |
+
# -------------------------- LOAD DATASET ---------------------------------
|
17 |
+
dataset_dict = load_dataset_dict()
|
18 |
+
|
19 |
+
# -------------------------- LOAD GRAPH -----------------------------------
|
20 |
+
|
21 |
+
def load_dot_to_graph(filename):
|
22 |
+
dot = graphviz.Source.from_file(filename)
|
23 |
+
source_lines = str(dot).splitlines()
|
24 |
+
source_lines.pop(0)
|
25 |
+
source_lines.pop(-1)
|
26 |
+
graph = graphviz.Digraph()
|
27 |
+
graph.body += source_lines
|
28 |
+
return graph, dot
|
29 |
+
|
30 |
+
st.title('Maximally activating image patches')
|
31 |
+
st.write('Visualize image patches that maximize the activation of layers in three models: ConvNeXt, ResNet, MobileNet')
|
32 |
+
|
33 |
+
# st.header('ConvNeXt')
|
34 |
+
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
|
35 |
+
convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
|
36 |
+
|
37 |
+
convnext_graph.graph_attr['size'] = '4,40'
|
38 |
+
|
39 |
+
# -------------------------- DISPLAY GRAPH -----------------------------------
|
40 |
+
|
41 |
+
def chosen_node_text(clicked_node_title):
|
42 |
+
clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
|
43 |
+
stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None
|
44 |
+
block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None
|
45 |
+
layer_id = clicked_node_title.split()[-1]
|
46 |
+
|
47 |
+
if 'embeddings' in layer_id:
|
48 |
+
display_text = 'Patchify layer'
|
49 |
+
activation_key = 'embeddings.patch_embeddings'
|
50 |
+
elif 'downsampling' in layer_id:
|
51 |
+
display_text = f'Stage {stage_id} > Downsampling layer'
|
52 |
+
activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]'
|
53 |
+
else:
|
54 |
+
display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
|
55 |
+
activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}'
|
56 |
+
return display_text, activation_key
|
57 |
+
|
58 |
+
|
59 |
+
props = {
|
60 |
+
'hightlight_color': HIGHTLIGHT_COLOR,
|
61 |
+
'initial_state': {
|
62 |
+
'group_1_header': 'Choose an option from group 1',
|
63 |
+
'group_2_header': 'Choose an option from group 2'
|
64 |
+
}
|
65 |
+
}
|
66 |
+
|
67 |
+
convnext_tab, resnet_tab, mobilenet_tab = st.tabs(['ConvNeXt', 'ResNet', 'MobileNet'])
|
68 |
+
|
69 |
+
with convnext_tab:
|
70 |
+
col1, col2 = st.columns((2,5))
|
71 |
+
col1.markdown("#### Architecture")
|
72 |
+
col1.write('')
|
73 |
+
col1.write('Click on a layer below to generate top-k maximally activating image patches')
|
74 |
+
col1.graphviz_chart(convnext_graph)
|
75 |
+
|
76 |
+
with col2:
|
77 |
+
st.markdown("#### Output")
|
78 |
+
nodes = on_click_graph(key='toggle_buttons', **props)
|
79 |
+
|
80 |
+
# -------------------------- DISPLAY OUTPUT -----------------------------------
|
81 |
+
|
82 |
+
if nodes != None:
|
83 |
+
clicked_node_title = nodes["choice"]["node_title"]
|
84 |
+
clicked_node_id = nodes["choice"]["node_id"]
|
85 |
+
display_text, activation_key = chosen_node_text(clicked_node_title)
|
86 |
+
col2.write(f'**Chosen layer:** {display_text}')
|
87 |
+
# col2.write(f'**Activation key:** {activation_key}')
|
88 |
+
|
89 |
+
hightlight_syle = f'''
|
90 |
+
<style>
|
91 |
+
div[data-stale]:has(iframe) {{
|
92 |
+
height: 0;
|
93 |
+
}}
|
94 |
+
#{clicked_node_id}>polygon {{
|
95 |
+
fill: {HIGHTLIGHT_COLOR};
|
96 |
+
stroke: {HIGHTLIGHT_COLOR};
|
97 |
+
}}
|
98 |
+
</style>
|
99 |
+
'''
|
100 |
+
col2.markdown(hightlight_syle, unsafe_allow_html=True)
|
101 |
+
|
102 |
+
with col2:
|
103 |
+
layer_infos = None
|
104 |
+
with st.form('top_k_form'):
|
105 |
+
activation_path = './data/activation/convnext_activation.json'
|
106 |
+
activation = load_activation(activation_path)
|
107 |
+
num_channels = activation[activation_key].shape[1]
|
108 |
+
|
109 |
+
top_k = st.slider('Choose K for top-K maximally activating patches', 1,20, value=10)
|
110 |
+
channel_start, channel_end = st.slider(
|
111 |
+
'Choose channel range of this layer (recommend to choose small range less than 30)',
|
112 |
+
1, num_channels, value=(1, 30))
|
113 |
+
summit_button = st.form_submit_button('Generate image patches')
|
114 |
+
if summit_button:
|
115 |
+
|
116 |
+
activation = activation[activation_key][:top_k,:,:]
|
117 |
+
layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
|
118 |
+
# st.write(channel_start, channel_end)
|
119 |
+
# st.write(activation.shape, activation.shape[1])
|
120 |
+
|
121 |
+
if layer_infos != None:
|
122 |
+
num_cols, num_rows = top_k, channel_end - channel_start + 1
|
123 |
+
# num_rows = activation.shape[1]
|
124 |
+
top_k_coor_max_ = activation
|
125 |
+
st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
|
126 |
+
|
127 |
+
for row in range(channel_start, channel_end+1):
|
128 |
+
if row == channel_start:
|
129 |
+
top_margin = 50
|
130 |
+
fig = make_subplots(
|
131 |
+
rows=1, cols=num_cols,
|
132 |
+
subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
|
133 |
+
else:
|
134 |
+
top_margin = 0
|
135 |
+
fig = make_subplots(rows=1, cols=num_cols)
|
136 |
+
for col in range(1, num_cols+1):
|
137 |
+
k, c = col-1, row-1
|
138 |
+
img_index = int(top_k_coor_max_[k, c, 3])
|
139 |
+
activation_value = top_k_coor_max_[k, c, 0]
|
140 |
+
img = dataset_dict[img_index//10_000][img_index%10_000]['image']
|
141 |
+
class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
|
142 |
+
class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
|
143 |
+
|
144 |
+
idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
|
145 |
+
x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
|
146 |
+
img = np.array(img)[y1:y2, x1:x2, :]
|
147 |
+
|
148 |
+
hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
|
149 |
+
fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
|
150 |
+
fig.update_xaxes(showticklabels=False, showgrid=False)
|
151 |
+
fig.update_yaxes(showticklabels=False, showgrid=False)
|
152 |
+
fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
|
153 |
+
fig.update_layout(showlegend=False, yaxis_title=row)
|
154 |
+
fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
|
155 |
+
fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
|
156 |
+
st.plotly_chart(fig, use_container_width=True)
|
157 |
+
|
158 |
+
|
159 |
+
else:
|
160 |
+
col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
|
161 |
+
col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0};""", unsafe_allow_html=True)
|
pages/2_SmoothGrad.py
CHANGED
@@ -1,3 +1,145 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
from backend.utils import make_grid, load_dataset
|
6 |
+
|
7 |
+
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
|
8 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from matplotlib.backends.backend_agg import RendererAgg
|
12 |
+
_lock = RendererAgg.lock
|
13 |
+
|
14 |
+
st.set_page_config(layout='wide')
|
15 |
+
BACKGROUND_COLOR = '#bcd0e7'
|
16 |
+
|
17 |
+
|
18 |
+
st.title('Feature attribution with SmoothGrad')
|
19 |
+
st.write('Which features are responsible for the current prediction? ')
|
20 |
+
|
21 |
+
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
22 |
+
|
23 |
+
# --------------------------- LOAD function -----------------------------
|
24 |
+
|
25 |
+
@st.cache(allow_output_mutation=True)
|
26 |
+
def load_images(image_ids):
|
27 |
+
images = []
|
28 |
+
for image_id in image_ids:
|
29 |
+
dataset = load_dataset(image_id//10000)
|
30 |
+
images.append(dataset[image_id%10000])
|
31 |
+
return images
|
32 |
+
|
33 |
+
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
34 |
+
def load_model(model_name):
|
35 |
+
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
36 |
+
if model_name == 'ResNet':
|
37 |
+
model_file_path = 'microsoft/resnet-50'
|
38 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
39 |
+
model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
40 |
+
model.eval()
|
41 |
+
elif model_name == 'ConvNeXt':
|
42 |
+
model_file_path = 'facebook/convnext-tiny-224'
|
43 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
44 |
+
model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
45 |
+
model.eval()
|
46 |
+
else:
|
47 |
+
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
|
48 |
+
model.eval()
|
49 |
+
feature_extractor = None
|
50 |
+
return model, feature_extractor
|
51 |
+
|
52 |
+
images = []
|
53 |
+
image_ids = []
|
54 |
+
# INPUT ------------------------------
|
55 |
+
st.header('Input')
|
56 |
+
with st.form('smooth_grad_form'):
|
57 |
+
st.markdown('**Model and Input Setting**')
|
58 |
+
selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
|
59 |
+
# selected_image_set = st.selectbox('Image set', ['Random set', 'User-defined set'])
|
60 |
+
selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
|
61 |
+
|
62 |
+
# if selected_image_set == 'Class set':
|
63 |
+
# class_labels = imagenet_df.ClassLabel.unique().tolist()
|
64 |
+
# class_labels.sort()
|
65 |
+
# selected_classes = st.multiselect('Class filter', options=['All'] + class_labels)
|
66 |
+
# if not ('All' in selected_classes or len(selected_classes) == 0):
|
67 |
+
# imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
|
68 |
+
# no_images = st.slider('Number of images', 1, len(imagenet_df), value=10)
|
69 |
+
# image_ids = random.sample(imagenet_df.index.tolist(), k=no_images)
|
70 |
+
|
71 |
+
|
72 |
+
# user_defined_button = st.form_submit_button('User-defined set')
|
73 |
+
# random_set_button = st.form_submit_button('Random set')
|
74 |
+
|
75 |
+
# if user_defined_button:
|
76 |
+
# text = st.text_area('Specific Image IDs', value='0')
|
77 |
+
# image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
|
78 |
+
# if random_set_button:
|
79 |
+
# no_images = st.slider('Number of images', 1, 50, value=10)
|
80 |
+
# image_ids = random.sample(list(range(50_000)), k=no_images)
|
81 |
+
|
82 |
+
summit_button = st.form_submit_button('Set')
|
83 |
+
if summit_button:
|
84 |
+
setting_container = st.container()
|
85 |
+
# for id in image_ids:
|
86 |
+
# images = load_images(image_ids)
|
87 |
+
|
88 |
+
with st.form('2nd_form'):
|
89 |
+
st.markdown('**Image set setting**')
|
90 |
+
if selected_image_set == 'Random set':
|
91 |
+
no_images = st.slider('Number of images', 1, 50, value=10)
|
92 |
+
image_ids = random.sample(list(range(50_000)), k=no_images)
|
93 |
+
else:
|
94 |
+
text = st.text_area('Specific Image IDs', value='0')
|
95 |
+
image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
|
96 |
+
|
97 |
+
run_button = st.form_submit_button('Display output')
|
98 |
+
if run_button:
|
99 |
+
for id in image_ids:
|
100 |
+
images = load_images(image_ids)
|
101 |
+
|
102 |
+
st.header('Output')
|
103 |
+
|
104 |
+
models = {}
|
105 |
+
feature_extractors = {}
|
106 |
+
|
107 |
+
for i, model_name in enumerate(selected_models):
|
108 |
+
models[model_name], feature_extractors[model_name] = load_model(model_name)
|
109 |
+
|
110 |
+
|
111 |
+
# DISPLAY ----------------------------------
|
112 |
+
header_cols = st.columns([1, 1] + [2]*len(selected_models))
|
113 |
+
header_cols[0].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Image ID</b></div>', unsafe_allow_html=True)
|
114 |
+
header_cols[1].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Original Image</b></div>', unsafe_allow_html=True)
|
115 |
+
for i, model_name in enumerate(selected_models):
|
116 |
+
header_cols[i + 2].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>{model_name}</b></div>', unsafe_allow_html=True)
|
117 |
+
|
118 |
+
grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
|
119 |
+
# grids[0][0].write('Image ID')
|
120 |
+
# grids[0][1].write('Original image')
|
121 |
+
|
122 |
+
# for i, model_name in enumerate(selected_models):
|
123 |
+
# models[model_name], feature_extractors[model_name] = load_model(model_name)
|
124 |
+
|
125 |
+
|
126 |
+
@st.cache(allow_output_mutation=True)
|
127 |
+
def generate_images(image, model_name):
|
128 |
+
return generate_smoothgrad_mask(
|
129 |
+
image, model_name,
|
130 |
+
models[model_name], feature_extractors[model_name], num_samples=10)
|
131 |
+
|
132 |
+
with _lock:
|
133 |
+
for j, (image_id, image_dict) in enumerate(zip(image_ids, images)):
|
134 |
+
grids[j][0].write(f'{image_id}. {image_dict["label"]}')
|
135 |
+
image = image_dict['image']
|
136 |
+
ori_image = ShowImage(np.asarray(image))
|
137 |
+
grids[j][1].image(ori_image)
|
138 |
+
|
139 |
+
for i, model_name in enumerate(selected_models):
|
140 |
+
# ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
|
141 |
+
# model_name, models[model_name], feature_extractors[model_name], num_samples=10)
|
142 |
+
heatmap_image, masked_image = generate_images(image, model_name)
|
143 |
+
# grids[j][1].image(ori_image)
|
144 |
+
grids[j][i*2+2].image(heatmap_image)
|
145 |
+
grids[j][i*2+3].image(masked_image)
|
pages/3_ImageNet1k.py
CHANGED
@@ -1,3 +1,48 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
from backend.utils import load_dataset, use_container_width_percentage
|
5 |
+
|
6 |
+
st.title('ImageNet-1k')
|
7 |
+
st.markdown('This page shows the summary of 50,000 images in the validation set of [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)')
|
8 |
+
|
9 |
+
# SCREEN_WIDTH, SCREEN_HEIGHT = 2560, 1664
|
10 |
+
|
11 |
+
with st.spinner("Loading dataset..."):
|
12 |
+
dataset_dict = {}
|
13 |
+
for data_index in range(5):
|
14 |
+
dataset_dict[data_index] = load_dataset(data_index)
|
15 |
+
|
16 |
+
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
17 |
+
|
18 |
+
class_labels = imagenet_df.ClassLabel.unique().tolist()
|
19 |
+
class_labels.sort()
|
20 |
+
selected_classes = st.multiselect('Class filter: ', options=['All'] + class_labels)
|
21 |
+
if not ('All' in selected_classes or len(selected_classes) == 0):
|
22 |
+
imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
|
23 |
+
# st.write(class_labels)
|
24 |
+
|
25 |
+
col1, col2 = st.columns([2, 1])
|
26 |
+
with col1:
|
27 |
+
st.dataframe(imagenet_df)
|
28 |
+
use_container_width_percentage(100)
|
29 |
+
|
30 |
+
with col2:
|
31 |
+
st.text_area('Type anything here to copy later :)')
|
32 |
+
image = None
|
33 |
+
with st.form("display image"):
|
34 |
+
img_index = st.text_input('Image ID to display')
|
35 |
+
try:
|
36 |
+
img_index = int(img_index)
|
37 |
+
except:
|
38 |
+
pass
|
39 |
+
|
40 |
+
submitted = st.form_submit_button('Display this image')
|
41 |
+
if submitted:
|
42 |
+
image = dataset_dict[img_index//10_000][img_index%10_000]['image']
|
43 |
+
class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
|
44 |
+
class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
|
45 |
+
if image != None:
|
46 |
+
st.image(image)
|
47 |
+
st.write('**Class label:** ', class_label)
|
48 |
+
st.write('\n**Class id:** ', str(class_id))
|
requirements.txt
CHANGED
@@ -1,3 +1,17 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
captum==0.5.0
|
2 |
+
deta==1.1.0
|
3 |
+
graphviz==0.20.1
|
4 |
+
Markdown==3.4.1
|
5 |
+
matplotlib==3.6.2
|
6 |
+
numpy==1.22.3
|
7 |
+
opencv_python_headless==4.6.0.66
|
8 |
+
pandas==1.5.2
|
9 |
+
Pillow==9.3.0
|
10 |
+
plotly==5.11.0
|
11 |
+
scipy==1.9.3
|
12 |
+
setuptools==65.5.0
|
13 |
+
streamlit==1.15.2
|
14 |
+
torch==1.10.1
|
15 |
+
torchvision==0.11.2
|
16 |
+
tqdm==4.64.1
|
17 |
+
transformers==4.25.1
|