added disentanglement also for vase art

- backend/disentangle_concepts.py +15 -4
- data/vase_annotated_files/seeds0000-20000.pkl +3 -0
- data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv +3 -0
- data/vase_model_files/network-snapshot-003800.pkl +3 -0
- pages/{1_Disentanglement.py → 1_Omniart_Disentanglement.py} +0 -0
- pages/3_Oxford_Vases_Disentanglement.py +171 -0
- pages/3_todo.py +0 -124
backend/disentangle_concepts.py
CHANGED
@@ -28,10 +28,21 @@ def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=
     else:
         col = 'w_vectors'

[4 lines removed; their content is not shown in this view]
+    if type(type_bin) == str or len(type_bin) == 1:
+        abstracts = np.array([float(ann) for ann in df[type_bin]])
+        abstract_idxs = list(np.argsort(abstracts))[:samples]
+        repr_idxs = list(np.argsort(abstracts))[-samples:]
+        X = np.array([annotations[col][i] for i in abstract_idxs+repr_idxs])
+    elif len(type_bin) == 2:
+        print('Using two concepts for separation space')
+        first_concept = np.array([float(ann) for ann in df[type_bin[0]]])
+        second_concept = np.array([float(ann) for ann in df[type_bin[1]]])
+        first_idxs = list(np.argsort(first_concept))[:samples]
+        second_idxs = list(np.argsort(second_concept))[:samples]
+        X = np.array([annotations[col][i] for i in first_idxs+second_idxs])
+    else:
+        print('Error: type_bin must be either a string or a list of strings of len 2')
+        return
     X = X.reshape((2*samples, 512))
     y = np.array([1]*samples + [0]*samples)
     x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
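For reference, a hedged usage sketch of the two new branches. The return values and the `latent_space` keyword mirror the call made in pages/3_Oxford_Vases_Disentanglement.py below; `annotations` and `ann_df` stand for the pickle and CSV that page loads, so treat this as a sketch under those assumptions rather than the module's documented API.

# Minimal sketch, assuming annotations/ann_df are loaded as in the new Streamlit page.
from backend.disentangle_concepts import get_separation_space

# Single concept (string or one-element list): contrasts the lowest- vs highest-scored samples.
vec, n_nodes, imp_nodes, perf = get_separation_space('AMPHORA', annotations, ann_df, latent_space='W')

# Two concepts: the samples selected for each concept form the two classes of the classifier.
vec, n_nodes, imp_nodes, perf = get_separation_space(['AMPHORA', 'CHALICE'], annotations, ann_df, latent_space='W')

# Any other input hits the new error branch: it prints a message and returns None.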
data/vase_annotated_files/seeds0000-20000.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e790910bf45c0d5a84e74c9011b88012f59d0fc27b19987c890b891c57ab739c
+size 125913423
data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e258361e0db7c208ae67654c08ed5b900df10980e82e84bcddd3de89428f679a
+size 30853761
data/vase_model_files/network-snapshot-003800.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42be0a24e7021dc66a9353c3a904494bb8e64b62e00e535ad3b03ad18238b0d2
+size 357349976
pages/{1_Disentanglement.py → 1_Omniart_Disentanglement.py}
RENAMED
File without changes
pages/3_Oxford_Vases_Disentanglement.py
ADDED
@@ -0,0 +1,171 @@
+import streamlit as st
+import pickle
+import pandas as pd
+import numpy as np
+import random
+import torch
+
+from matplotlib.backends.backend_agg import RendererAgg
+
+from backend.disentangle_concepts import *
+import torch_utils
+import dnnlib
+import legacy
+
+_lock = RendererAgg.lock
+
+
+st.set_page_config(layout='wide')
+BACKGROUND_COLOR = '#bcd0e7'
+SECONDARY_COLOR = '#bce7db'
+
+
+st.title('Disentanglement studies on the Oxford Vases Dataset')
+st.markdown(
+    """
+    This is a demo of the Disentanglement studies on the [Oxford Vases Dataset](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/).
+    """,
+    unsafe_allow_html=False,)
+
+annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
+with open(annotations_file, 'rb') as f:
+    annotations = pickle.load(f)
+
+ann_df = pd.read_csv('./data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv')
+labels = ann_df.columns
+
+if 'image_id' not in st.session_state:
+    st.session_state.image_id = 0
+if 'concept_ids' not in st.session_state:
+    st.session_state.concept_ids =['AMPHORA']
+if 'space_id' not in st.session_state:
+    st.session_state.space_id = 'W'
+
+# def on_change_random_input():
+#     st.session_state.image_id = st.session_state.image_id
+
+# ----------------------------- INPUT ----------------------------------
+st.header('Input')
+input_col_1, input_col_2, input_col_3 = st.columns(3)
+# --------------------------- INPUT column 1 ---------------------------
+with input_col_1:
+    with st.form('text_form'):
+
+        # image_id = st.number_input('Image ID: ', format='%d', step=1)
+        st.write('**Choose two options to disentangle**')
+        concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=['AMPHORA', 'CHALICE'])
+
+        st.write('**Choose a latent space to disentangle**')
+        space_id = st.selectbox('Space:', tuple(['Z', 'W']))
+
+        choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
+
+        if choose_text_button:
+            concept_ids = list(concept_ids)
+            st.session_state.concept_ids = concept_ids
+            space_id = str(space_id)
+            st.session_state.space_id = space_id
+            # st.write(image_id, st.session_state.image_id)
+
+# ---------------------------- SET UP OUTPUT ------------------------------
+epsilon_container = st.empty()
+st.header('Output')
+st.subheader('Concept vector')
+
+# perform attack container
+# header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
+# output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
+header_col_1, header_col_2 = st.columns([5,1])
+output_col_1, output_col_2 = st.columns([5,1])
+
+st.subheader('Derivations along the concept vector')
+
+# prediction error container
+error_container = st.empty()
+smoothgrad_header_container = st.empty()
+
+# smoothgrad container
+smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
+smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
+
+# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
+with output_col_1:
+    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id)
+    # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
+    st.write('Concept vector', separation_vector)
+    header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
+
+# ----------------------------- INPUT column 2 & 3 ----------------------------
+with input_col_2:
+    with st.form('image_form'):
+
+        # image_id = st.number_input('Image ID: ', format='%d', step=1)
+        st.write('**Choose or generate a random image to test the disentanglement**')
+        chosen_image_id_input = st.empty()
+        image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+        choose_image_button = st.form_submit_button('Choose the defined image')
+        random_id = st.form_submit_button('Generate a random image')
+
+        if random_id:
+            image_id = random.randint(0, 50000)
+            st.session_state.image_id = image_id
+            chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+        if choose_image_button:
+            image_id = int(image_id)
+            st.session_state.image_id = int(image_id)
+            # st.write(image_id, st.session_state.image_id)
+
+with input_col_3:
+    with st.form('Variate along the disentangled concept'):
+        st.write('**Set range of change**')
+        chosen_epsilon_input = st.empty()
+        epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
+        epsilon_button = st.form_submit_button('Choose the defined lambda')
+        st.write('**Select hierarchical levels to manipulate**')
+        layers = st.multiselect('Layers:', tuple(range(14)))
+        if len(layers) == 0:
+            layers = None
+        print(layers)
+        layers_button = st.form_submit_button('Choose the defined layers')
+
+
+# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
+
+#model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
+with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
+    model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
+
+if st.session_state.space_id == 'Z':
+    original_image_vec = annotations['z_vectors'][st.session_state.image_id]
+else:
+    original_image_vec = annotations['w_vectors'][st.session_state.image_id]
+
+img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
+# input_image = original_image_dict['image']
+# input_label = original_image_dict['label']
+# input_id = original_image_dict['id']
+
+with smoothgrad_col_3:
+    st.image(img)
+    smooth_head_3.write(f'Base image')
+
+
+images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
+
+with smoothgrad_col_1:
+    st.image(images[0])
+    smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
+
+with smoothgrad_col_2:
+    st.image(images[1])
+    smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
+
+with smoothgrad_col_4:
+    st.image(images[3])
+    smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
+
+with smoothgrad_col_5:
+    st.image(images[4])
+    smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
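Taken together, a condensed, non-Streamlit sketch of the pipeline this page wires up. Every call below appears verbatim in the file above; the concrete concepts, image index, lambda range and layer choice are illustrative assumptions, not values the commit prescribes.

# Headless sketch of the vase-disentanglement pipeline (illustrative values).
import pickle
import pandas as pd
import dnnlib
import legacy
from backend.disentangle_concepts import *

with open('./data/vase_annotated_files/seeds0000-20000.pkl', 'rb') as f:
    annotations = pickle.load(f)  # holds 'z_vectors' and 'w_vectors'
ann_df = pd.read_csv('./data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv')

with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
    model = legacy.load_network_pkl(f)['G_ema'].to('cpu')

# 1. Separation vector between two vase shapes in the chosen latent space.
separation_vector, n_nodes, imp_nodes, performance = get_separation_space(
    ['AMPHORA', 'CHALICE'], annotations, ann_df, latent_space='W')

# 2. Render the unmodified image for one annotated latent vector.
w = annotations['w_vectors'][0]
base_img = generate_original_image(w, model, latent_space='W')

# 3. Walk along the separation vector and render the variations.
images, lambdas = regenerate_images(
    model, w, separation_vector,
    min_epsilon=-5, max_epsilon=5, latent_space='W', layers=None)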
pages/3_todo.py
DELETED
@@ -1,124 +0,0 @@
-import streamlit as st
-import pandas as pd
-import numpy as np
-import random
-from backend.utils import make_grid, load_dataset, load_model, load_images
-
-from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
-from transformers import AutoFeatureExtractor, AutoModelForImageClassification
-import torch
-
-from matplotlib.backends.backend_agg import RendererAgg
-_lock = RendererAgg.lock
-
-st.set_page_config(layout='wide')
-BACKGROUND_COLOR = '#bcd0e7'
-
-
-st.title('Feature attribution visualization with SmoothGrad')
-st.write("""> **Which features are responsible for the current prediction of ConvNeXt?**
-
-In machine learning, it is helpful to identify the significant features of the input (e.g., pixels for images) that affect the model's prediction.
-If the model makes an incorrect prediction, we might want to determine which features contributed to the mistake.
-To do this, we can generate a feature importance mask, which is a grayscale image with the same size as the original image.
-The brightness of each pixel in the mask represents the importance of that feature to the model's prediction.
-
-There are various methods to calculate an image sensitivity mask for a specific prediction.
-One simple way is to use the gradient of a class prediction neuron concerning the input pixels, indicating how the prediction is affected by small pixel changes.
-However, this method usually produces a noisy mask.
-To reduce the noise, the SmoothGrad technique as described in [SmoothGrad: Removing noise by adding noise](https://arxiv.org/abs/1706.03825) by Daniel _et al_ is used,
-which adds Gaussian noise to multiple copies of the image and averages the resulting gradients.
-""")
-
-instruction_text = """Users need to input the model(s), type of image set and image set setting to use this functionality.
-1. Choose model: Users can choose one or more models for comparison.
-There are 3 models supported: [ConvNeXt](https://huggingface.co/facebook/convnext-tiny-224),
-[ResNet](https://huggingface.co/microsoft/resnet-50) and [MobileNet](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/).
-These 3 models have similar number of parameters.
-\n2. Choose type of Image set: There are 2 types of Image set. They are _User-defined set_ and _Random set_.
-\n3. Image set setting: If users choose _User-defined set_ in Image set,
-users need to enter a list of image IDs separated by commas (,). For example, `0,1,4,7` is a valid input.
-Check the page [ImageNet1k](/ImageNet1k) to see all the Image IDs.
-If users choose _Random set_ in Image set, users just need to choose the number of random images to display here.
-"""
-with st.expander("See more instruction", expanded=False):
-    st.write(instruction_text)
-
-
-imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
-
-# --------------------------- LOAD function -----------------------------
-
-
-images = []
-image_ids = []
-# INPUT ------------------------------
-st.header('Input')
-with st.form('smooth_grad_form'):
-    st.markdown('**Model and Input Setting**')
-    selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
-    selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
-
-    summit_button = st.form_submit_button('Set')
-    if summit_button:
-        setting_container = st.container()
-        # for id in image_ids:
-        #     images = load_images(image_ids)
-
-with st.form('2nd_form'):
-    st.markdown('**Image set setting**')
-    if selected_image_set == 'Random set':
-        no_images = st.slider('Number of images', 1, 50, value=10)
-        image_ids = random.sample(list(range(50_000)), k=no_images)
-    else:
-        text = st.text_area('Specific Image IDs', value='0')
-        image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
-
-    run_button = st.form_submit_button('Display output')
-    if run_button:
-        for id in image_ids:
-            images = load_images(image_ids)
-
-st.header('Output')
-
-models = {}
-feature_extractors = {}
-
-for i, model_name in enumerate(selected_models):
-    models[model_name], feature_extractors[model_name] = load_model(model_name)
-
-
-# DISPLAY ----------------------------------
-if run_button:
-    header_cols = st.columns([1, 1] + [2]*len(selected_models))
-    header_cols[0].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Image ID</b></div>', unsafe_allow_html=True)
-    header_cols[1].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Original Image</b></div>', unsafe_allow_html=True)
-    for i, model_name in enumerate(selected_models):
-        header_cols[i + 2].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>{model_name}</b></div>', unsafe_allow_html=True)
-
-    grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
-
-
-    @st.cache(allow_output_mutation=True)
-    # @st.cache_data
-    def generate_images(image_id, model_name):
-        j = image_ids.index(image_id)
-        image = images[j]['image']
-        return generate_smoothgrad_mask(
-            image, model_name,
-            models[model_name], feature_extractors[model_name], num_samples=10)
-
-    with _lock:
-        for j, (image_id, image_dict) in enumerate(zip(image_ids, images)):
-            grids[j][0].write(f'{image_id}. {image_dict["label"]}')
-            image = image_dict['image']
-            ori_image = ShowImage(np.asarray(image))
-            grids[j][1].image(ori_image)
-
-            for i, model_name in enumerate(selected_models):
-                # ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
-                #     model_name, models[model_name], feature_extractors[model_name], num_samples=10)
-                heatmap_image, masked_image = generate_images(image_id, model_name)
-                # grids[j][1].image(ori_image)
-                grids[j][i*2+2].image(heatmap_image)
-                grids[j][i*2+3].image(masked_image)