Spaces:
Build error
Build error
taquynhnga
commited on
Commit
β’
0c1e42b
1
Parent(s):
8a287fa
added adversarial attacks & change to streamlit 1.19.0
Browse files- .vscode/settings.json +5 -3
- README.md +1 -1
- backend/adversarial_attack.py +99 -0
- backend/load_file.py +8 -4
- backend/maximally_activating_patches.py +4 -2
- backend/smooth_grad.py +8 -6
- backend/utils.py +43 -1
- frontend/images/equal-sign.png +0 -0
- frontend/images/minus-sign-2.png +0 -0
- frontend/images/minus-sign-3.png +0 -0
- frontend/images/minus-sign-4.png +0 -0
- frontend/images/minus-sign-5.png +0 -0
- frontend/images/minus-sign.png +0 -0
- frontend/images/plus-sign-2.png +0 -0
- frontend/images/plus-sign.png +0 -0
- load_file.py +0 -37
- pages/1_Maximally_activating_patches.py +2 -2
- pages/2_SmoothGrad.py +35 -51
- pages/3_Adversarial_attack.py +184 -0
- pages/{3_ImageNet1k.py β 4_ImageNet1k.py} +0 -0
- requirements.txt +4 -5
.vscode/settings.json
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"python.analysis.extraPaths": [
|
3 |
+
"./Visual-Explanation-Methods-PyTorch"
|
4 |
+
]
|
5 |
+
}
|
README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
---
|
2 |
title: CNNs Interpretation Visualization
|
3 |
emoji: π‘
|
4 |
-
colorFrom:
|
5 |
colorTo: green
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
|
|
1 |
---
|
2 |
title: CNNs Interpretation Visualization
|
3 |
emoji: π‘
|
4 |
+
colorFrom: blue
|
5 |
colorTo: green
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
backend/adversarial_attack.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
from matplotlib import pylab as P
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import TensorDataset
|
9 |
+
from torchvision import transforms
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
13 |
+
|
14 |
+
from torchvex.base import ExplanationMethod
|
15 |
+
from torchvex.utils.normalization import clamp_quantile
|
16 |
+
|
17 |
+
from backend.utils import load_image, load_model
|
18 |
+
from backend.smooth_grad import generate_smoothgrad_mask
|
19 |
+
|
20 |
+
import streamlit as st
|
21 |
+
|
22 |
+
IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1,3,1,1])
|
23 |
+
IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1,3,1,1])
|
24 |
+
|
25 |
+
def deprocess_image(image_inputs):
|
26 |
+
return (image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN) * 255
|
27 |
+
|
28 |
+
|
29 |
+
def feed_forward(input_image):
|
30 |
+
model, feature_extractor = load_model('ConvNeXt')
|
31 |
+
inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
|
32 |
+
logits = model(inputs).logits
|
33 |
+
prediction_prob = F.softmax(logits, dim=-1).max() # prediction probability
|
34 |
+
# prediction class id, start from 1 to 1000 so it needs to +1 in the end
|
35 |
+
prediction_class = logits.argmax(-1).item()
|
36 |
+
prediction_label = model.config.id2label[prediction_class] # prediction class label
|
37 |
+
return prediction_prob, prediction_class, prediction_label
|
38 |
+
|
39 |
+
# FGSM attack code
|
40 |
+
def fgsm_attack(image, epsilon, data_grad):
|
41 |
+
# Collect the element-wise sign of the data gradient and normalize it
|
42 |
+
sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0
|
43 |
+
perturbed_image = image + epsilon*sign_data_grad
|
44 |
+
return perturbed_image
|
45 |
+
|
46 |
+
# perform attack on the model
|
47 |
+
def perform_attack(input_image, target, epsilon):
|
48 |
+
model, feature_extractor = load_model("ConvNeXt")
|
49 |
+
# preprocess input image
|
50 |
+
inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
|
51 |
+
inputs.requires_grad = True
|
52 |
+
|
53 |
+
# predict
|
54 |
+
logits = model(inputs).logits
|
55 |
+
prediction_prob = F.softmax(logits, dim=-1).max()
|
56 |
+
prediction_class = logits.argmax(-1).item()
|
57 |
+
prediction_label = model.config.id2label[prediction_class]
|
58 |
+
|
59 |
+
# Calculate the loss
|
60 |
+
loss = F.nll_loss(logits, torch.tensor([target]))
|
61 |
+
|
62 |
+
# Zero all existing gradients
|
63 |
+
model.zero_grad()
|
64 |
+
|
65 |
+
# Calculate gradients of model in backward pass
|
66 |
+
loss.backward()
|
67 |
+
|
68 |
+
# Collect datagrad
|
69 |
+
data_grad = inputs.grad.data
|
70 |
+
|
71 |
+
# Call FGSM Attack
|
72 |
+
perturbed_data = fgsm_attack(inputs, epsilon, data_grad)
|
73 |
+
|
74 |
+
# Re-classify the perturbed image
|
75 |
+
new_prediction = model(perturbed_data).logits
|
76 |
+
new_pred_prob = F.softmax(new_prediction, dim=-1).max()
|
77 |
+
new_pred_class = new_prediction.argmax(-1).item()
|
78 |
+
new_pred_label = model.config.id2label[new_pred_class]
|
79 |
+
|
80 |
+
return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label
|
81 |
+
|
82 |
+
|
83 |
+
def find_smallest_epsilon(input_image, target):
|
84 |
+
epsilons = [i*0.001 for i in range(1000)]
|
85 |
+
|
86 |
+
for epsilon in epsilons:
|
87 |
+
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, target, epsilon)
|
88 |
+
if new_id != target:
|
89 |
+
return perturbed_data, new_prob, new_id, new_label, epsilon
|
90 |
+
return None
|
91 |
+
|
92 |
+
@st.cache_data
|
93 |
+
def generate_images(image_id, epsilon=0):
|
94 |
+
model, feature_extractor = load_model("ConvNeXt")
|
95 |
+
original_image_dict = load_image(image_id)
|
96 |
+
image = original_image_dict['image']
|
97 |
+
return generate_smoothgrad_mask(
|
98 |
+
image, 'ConvNeXt',
|
99 |
+
model, feature_extractor, num_samples=10, return_mask=True)
|
backend/load_file.py
CHANGED
@@ -19,7 +19,11 @@ def load_json(filename):
|
|
19 |
loaded_dict = json.loads(read_file.read())
|
20 |
loaded_dict = OrderedDict(loaded_dict)
|
21 |
for k, v in loaded_dict.items():
|
22 |
-
|
|
|
|
|
|
|
|
|
23 |
return loaded_dict
|
24 |
|
25 |
class NumpyEncoder(json.JSONEncoder):
|
@@ -32,6 +36,6 @@ class NumpyEncoder(json.JSONEncoder):
|
|
32 |
# save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
|
33 |
# save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
|
34 |
|
35 |
-
file = load_json('data/layer_infos/convnext_layer_infos.json')
|
36 |
-
print(type(file))
|
37 |
-
print(type(file['embeddings.patch_embeddings']))
|
|
|
19 |
loaded_dict = json.loads(read_file.read())
|
20 |
loaded_dict = OrderedDict(loaded_dict)
|
21 |
for k, v in loaded_dict.items():
|
22 |
+
if type(v) == list:
|
23 |
+
loaded_dict[k] = np.asarray(v)
|
24 |
+
else:
|
25 |
+
for k_, v_ in v.items():
|
26 |
+
v[k_] = np.asarray(v_)
|
27 |
return loaded_dict
|
28 |
|
29 |
class NumpyEncoder(json.JSONEncoder):
|
|
|
36 |
# save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
|
37 |
# save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
|
38 |
|
39 |
+
# file = load_json('data/layer_infos/convnext_layer_infos.json')
|
40 |
+
# print(type(file))
|
41 |
+
# print(type(file['embeddings.patch_embeddings']))
|
backend/maximally_activating_patches.py
CHANGED
@@ -4,12 +4,14 @@ import streamlit as st
|
|
4 |
from backend.load_file import load_json
|
5 |
|
6 |
|
7 |
-
@st.cache(allow_output_mutation=True)
|
|
|
8 |
def load_activation(filename):
|
9 |
activation = load_json(filename)
|
10 |
return activation
|
11 |
|
12 |
-
@st.cache(allow_output_mutation=True)
|
|
|
13 |
def load_dataset(data_index):
|
14 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
15 |
dataset = pickle.load(file)
|
|
|
4 |
from backend.load_file import load_json
|
5 |
|
6 |
|
7 |
+
# @st.cache(allow_output_mutation=True)
|
8 |
+
st.cache_data
|
9 |
def load_activation(filename):
|
10 |
activation = load_json(filename)
|
11 |
return activation
|
12 |
|
13 |
+
# @st.cache(allow_output_mutation=True)
|
14 |
+
@st.cache_data
|
15 |
def load_dataset(data_index):
|
16 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
17 |
dataset = pickle.load(file)
|
backend/smooth_grad.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
import sys
|
2 |
-
import os
|
3 |
import PIL
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
@@ -10,8 +8,8 @@ import torch
|
|
10 |
from torch.utils.data import TensorDataset
|
11 |
from torchvision import transforms
|
12 |
|
13 |
-
dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
|
14 |
-
sys.path.append(dirpath_to_modules)
|
15 |
|
16 |
from torchvex.base import ExplanationMethod
|
17 |
from torchvex.utils.normalization import clamp_quantile
|
@@ -212,7 +210,7 @@ def fig2img(fig):
|
|
212 |
img = Image.open(buf)
|
213 |
return img
|
214 |
|
215 |
-
def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25):
|
216 |
inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
|
217 |
|
218 |
smoothgrad_gen = SmoothGradient(
|
@@ -230,4 +228,8 @@ def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=No
|
|
230 |
# ori_image = ShowImage(image)
|
231 |
heat_map_image = ShowHeatMap(smoothgrad_mask)
|
232 |
masked_image = ShowMaskedImage(smoothgrad_mask, image)
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import PIL
|
2 |
from PIL import Image
|
3 |
import numpy as np
|
|
|
8 |
from torch.utils.data import TensorDataset
|
9 |
from torchvision import transforms
|
10 |
|
11 |
+
# dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
|
12 |
+
# sys.path.append(dirpath_to_modules)
|
13 |
|
14 |
from torchvex.base import ExplanationMethod
|
15 |
from torchvex.utils.normalization import clamp_quantile
|
|
|
210 |
img = Image.open(buf)
|
211 |
return img
|
212 |
|
213 |
+
def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25, return_mask=False):
|
214 |
inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
|
215 |
|
216 |
smoothgrad_gen = SmoothGradient(
|
|
|
228 |
# ori_image = ShowImage(image)
|
229 |
heat_map_image = ShowHeatMap(smoothgrad_mask)
|
230 |
masked_image = ShowMaskedImage(smoothgrad_mask, image)
|
231 |
+
|
232 |
+
if return_mask:
|
233 |
+
return heat_map_image, masked_image, smoothgrad_mask
|
234 |
+
else:
|
235 |
+
return heat_map_image, masked_image
|
backend/utils.py
CHANGED
@@ -14,12 +14,17 @@ from plotly import express as px
|
|
14 |
from plotly.subplots import make_subplots
|
15 |
from tqdm import trange
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
def load_dataset(data_index):
|
19 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
20 |
dataset = pickle.load(file)
|
21 |
return dataset
|
22 |
|
|
|
23 |
def load_dataset_dict():
|
24 |
dataset_dict = {}
|
25 |
progress_empty = st.empty()
|
@@ -33,6 +38,43 @@ def load_dataset_dict():
|
|
33 |
text_empty.empty()
|
34 |
return dataset_dict
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def make_grid(cols=None,rows=None):
|
37 |
grid = [0]*rows
|
38 |
for i in range(rows):
|
|
|
14 |
from plotly.subplots import make_subplots
|
15 |
from tqdm import trange
|
16 |
|
17 |
+
import torch
|
18 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
19 |
+
|
20 |
+
# @st.cache(allow_output_mutation=True)
|
21 |
+
@st.cache_resource
|
22 |
def load_dataset(data_index):
|
23 |
with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
|
24 |
dataset = pickle.load(file)
|
25 |
return dataset
|
26 |
|
27 |
+
@st.cache_resource
|
28 |
def load_dataset_dict():
|
29 |
dataset_dict = {}
|
30 |
progress_empty = st.empty()
|
|
|
38 |
text_empty.empty()
|
39 |
return dataset_dict
|
40 |
|
41 |
+
|
42 |
+
@st.cache_data
|
43 |
+
def load_image(image_id):
|
44 |
+
dataset = load_dataset(image_id//10000)
|
45 |
+
image = dataset[image_id%10000]
|
46 |
+
return image
|
47 |
+
|
48 |
+
@st.cache_data
|
49 |
+
def load_images(image_ids):
|
50 |
+
images = []
|
51 |
+
for image_id in image_ids:
|
52 |
+
image = load_image(image_id)
|
53 |
+
images.append(image)
|
54 |
+
return images
|
55 |
+
|
56 |
+
|
57 |
+
# @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
58 |
+
@st.cache_resource
|
59 |
+
def load_model(model_name):
|
60 |
+
with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
61 |
+
if model_name == 'ResNet':
|
62 |
+
model_file_path = 'microsoft/resnet-50'
|
63 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
64 |
+
model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
65 |
+
model.eval()
|
66 |
+
elif model_name == 'ConvNeXt':
|
67 |
+
model_file_path = 'facebook/convnext-tiny-224'
|
68 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
69 |
+
model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
70 |
+
model.eval()
|
71 |
+
else:
|
72 |
+
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
|
73 |
+
model.eval()
|
74 |
+
feature_extractor = None
|
75 |
+
return model, feature_extractor
|
76 |
+
|
77 |
+
|
78 |
def make_grid(cols=None,rows=None):
|
79 |
grid = [0]*rows
|
80 |
for i in range(rows):
|
frontend/images/equal-sign.png
ADDED
frontend/images/minus-sign-2.png
ADDED
frontend/images/minus-sign-3.png
ADDED
frontend/images/minus-sign-4.png
ADDED
frontend/images/minus-sign-5.png
ADDED
frontend/images/minus-sign.png
ADDED
frontend/images/plus-sign-2.png
ADDED
frontend/images/plus-sign.png
ADDED
load_file.py
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import pickle
|
3 |
-
import numpy as np
|
4 |
-
from collections import OrderedDict
|
5 |
-
|
6 |
-
def load_pickle(filename):
|
7 |
-
with open(filename, 'rb') as file:
|
8 |
-
data = pickle.load(file)
|
9 |
-
return data
|
10 |
-
|
11 |
-
def save_pickle_to_json(filename):
|
12 |
-
ordered_dict = load_pickle(filename)
|
13 |
-
json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
|
14 |
-
with open(filename.replace('.pkl', '.json'), 'w') as f:
|
15 |
-
f.write(json_obj)
|
16 |
-
|
17 |
-
def load_json(filename):
|
18 |
-
with open(filename, 'r') as read_file:
|
19 |
-
loaded_dict = json.loads(read_file.read())
|
20 |
-
loaded_dict = OrderedDict(loaded_dict)
|
21 |
-
for k, v in loaded_dict.items():
|
22 |
-
loaded_dict[k] = np.asarray(v)
|
23 |
-
return loaded_dict
|
24 |
-
|
25 |
-
class NumpyEncoder(json.JSONEncoder):
|
26 |
-
def default(self, obj):
|
27 |
-
if isinstance(obj, np.ndarray):
|
28 |
-
return obj.tolist()
|
29 |
-
return json.JSONEncoder.default(self, obj)
|
30 |
-
|
31 |
-
# save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
|
32 |
-
# save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
|
33 |
-
# save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
|
34 |
-
|
35 |
-
file = load_json('data/layer_infos/convnext_layer_infos.json')
|
36 |
-
print(type(file))
|
37 |
-
print(type(file['embeddings.patch_embeddings']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/1_Maximally_activating_patches.py
CHANGED
@@ -28,7 +28,7 @@ def load_dot_to_graph(filename):
|
|
28 |
return graph, dot
|
29 |
|
30 |
st.title('Maximally activating image patches')
|
31 |
-
st.write('Visualize image patches that maximize the activation of layers in
|
32 |
|
33 |
# st.header('ConvNeXt')
|
34 |
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
|
@@ -130,7 +130,7 @@ if nodes != None:
|
|
130 |
subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
|
131 |
else:
|
132 |
top_margin = 0
|
133 |
-
fig = make_subplots(rows=1, cols=num_cols)
|
134 |
for col in range(1, num_cols+1):
|
135 |
k, c = col-1, row-1
|
136 |
img_index = int(top_k_coor_max_[k, c, 3])
|
|
|
28 |
return graph, dot
|
29 |
|
30 |
st.title('Maximally activating image patches')
|
31 |
+
st.write('Visualize image patches that maximize the activation of layers in ConvNeXt model')
|
32 |
|
33 |
# st.header('ConvNeXt')
|
34 |
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
|
|
|
130 |
subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
|
131 |
else:
|
132 |
top_margin = 0
|
133 |
+
fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)
|
134 |
for col in range(1, num_cols+1):
|
135 |
k, c = col-1, row-1
|
136 |
img_index = int(top_k_coor_max_[k, c, 3])
|
pages/2_SmoothGrad.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import random
|
5 |
-
from backend.utils import make_grid, load_dataset
|
6 |
|
7 |
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
|
8 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
@@ -22,32 +22,34 @@ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
|
22 |
|
23 |
# --------------------------- LOAD function -----------------------------
|
24 |
|
25 |
-
@st.cache(allow_output_mutation=True)
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
51 |
|
52 |
images = []
|
53 |
image_ids = []
|
@@ -56,28 +58,7 @@ st.header('Input')
|
|
56 |
with st.form('smooth_grad_form'):
|
57 |
st.markdown('**Model and Input Setting**')
|
58 |
selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
|
59 |
-
# selected_image_set = st.selectbox('Image set', ['Random set', 'User-defined set'])
|
60 |
selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
|
61 |
-
|
62 |
-
# if selected_image_set == 'Class set':
|
63 |
-
# class_labels = imagenet_df.ClassLabel.unique().tolist()
|
64 |
-
# class_labels.sort()
|
65 |
-
# selected_classes = st.multiselect('Class filter', options=['All'] + class_labels)
|
66 |
-
# if not ('All' in selected_classes or len(selected_classes) == 0):
|
67 |
-
# imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
|
68 |
-
# no_images = st.slider('Number of images', 1, len(imagenet_df), value=10)
|
69 |
-
# image_ids = random.sample(imagenet_df.index.tolist(), k=no_images)
|
70 |
-
|
71 |
-
|
72 |
-
# user_defined_button = st.form_submit_button('User-defined set')
|
73 |
-
# random_set_button = st.form_submit_button('Random set')
|
74 |
-
|
75 |
-
# if user_defined_button:
|
76 |
-
# text = st.text_area('Specific Image IDs', value='0')
|
77 |
-
# image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
|
78 |
-
# if random_set_button:
|
79 |
-
# no_images = st.slider('Number of images', 1, 50, value=10)
|
80 |
-
# image_ids = random.sample(list(range(50_000)), k=no_images)
|
81 |
|
82 |
summit_button = st.form_submit_button('Set')
|
83 |
if summit_button:
|
@@ -123,8 +104,11 @@ grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
|
|
123 |
# models[model_name], feature_extractors[model_name] = load_model(model_name)
|
124 |
|
125 |
|
126 |
-
@st.cache(allow_output_mutation=True)
|
127 |
-
|
|
|
|
|
|
|
128 |
return generate_smoothgrad_mask(
|
129 |
image, model_name,
|
130 |
models[model_name], feature_extractors[model_name], num_samples=10)
|
@@ -139,7 +123,7 @@ with _lock:
|
|
139 |
for i, model_name in enumerate(selected_models):
|
140 |
# ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
|
141 |
# model_name, models[model_name], feature_extractors[model_name], num_samples=10)
|
142 |
-
heatmap_image, masked_image = generate_images(
|
143 |
# grids[j][1].image(ori_image)
|
144 |
grids[j][i*2+2].image(heatmap_image)
|
145 |
grids[j][i*2+3].image(masked_image)
|
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import random
|
5 |
+
from backend.utils import make_grid, load_dataset, load_model, load_images
|
6 |
|
7 |
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
|
8 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
|
|
22 |
|
23 |
# --------------------------- LOAD function -----------------------------
|
24 |
|
25 |
+
# @st.cache(allow_output_mutation=True)
|
26 |
+
# @st.cache_data
|
27 |
+
# def load_images(image_ids):
|
28 |
+
# images = []
|
29 |
+
# for image_id in image_ids:
|
30 |
+
# dataset = load_dataset(image_id//10000)
|
31 |
+
# images.append(dataset[image_id%10000])
|
32 |
+
# return images
|
33 |
+
|
34 |
+
# @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
|
35 |
+
# @st.cache_resource
|
36 |
+
# def load_model(model_name):
|
37 |
+
# with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
|
38 |
+
# if model_name == 'ResNet':
|
39 |
+
# model_file_path = 'microsoft/resnet-50'
|
40 |
+
# feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
41 |
+
# model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
42 |
+
# model.eval()
|
43 |
+
# elif model_name == 'ConvNeXt':
|
44 |
+
# model_file_path = 'facebook/convnext-tiny-224'
|
45 |
+
# feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
|
46 |
+
# model = AutoModelForImageClassification.from_pretrained(model_file_path)
|
47 |
+
# model.eval()
|
48 |
+
# else:
|
49 |
+
# model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
|
50 |
+
# model.eval()
|
51 |
+
# feature_extractor = None
|
52 |
+
# return model, feature_extractor
|
53 |
|
54 |
images = []
|
55 |
image_ids = []
|
|
|
58 |
with st.form('smooth_grad_form'):
|
59 |
st.markdown('**Model and Input Setting**')
|
60 |
selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
|
|
|
61 |
selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
summit_button = st.form_submit_button('Set')
|
64 |
if summit_button:
|
|
|
104 |
# models[model_name], feature_extractors[model_name] = load_model(model_name)
|
105 |
|
106 |
|
107 |
+
# @st.cache(allow_output_mutation=True)
|
108 |
+
@st.cache_data
|
109 |
+
def generate_images(image_id, model_name):
|
110 |
+
j = image_ids.index(image_id)
|
111 |
+
image = images[j]['image']
|
112 |
return generate_smoothgrad_mask(
|
113 |
image, model_name,
|
114 |
models[model_name], feature_extractors[model_name], num_samples=10)
|
|
|
123 |
for i, model_name in enumerate(selected_models):
|
124 |
# ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
|
125 |
# model_name, models[model_name], feature_extractors[model_name], num_samples=10)
|
126 |
+
heatmap_image, masked_image = generate_images(image_id, model_name)
|
127 |
# grids[j][1].image(ori_image)
|
128 |
grids[j][i*2+2].image(heatmap_image)
|
129 |
grids[j][i*2+3].image(masked_image)
|
pages/3_Adversarial_attack.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
from backend.utils import make_grid, load_dataset, load_model, load_image
|
6 |
+
|
7 |
+
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img, LoadImage, ShowHeatMap, ShowMaskedImage
|
8 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from matplotlib.backends.backend_agg import RendererAgg
|
12 |
+
|
13 |
+
from backend.adversarial_attack import *
|
14 |
+
|
15 |
+
_lock = RendererAgg.lock
|
16 |
+
|
17 |
+
st.set_page_config(layout='wide')
|
18 |
+
BACKGROUND_COLOR = '#bcd0e7'
|
19 |
+
SECONDARY_COLOR = '#bce7db'
|
20 |
+
|
21 |
+
|
22 |
+
st.title('Adversarial Attack')
|
23 |
+
st.write('How adversarial attacks affect ConvNeXt interpretation?')
|
24 |
+
|
25 |
+
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
|
26 |
+
image_id = None
|
27 |
+
|
28 |
+
if 'image_id' not in st.session_state:
|
29 |
+
st.session_state.image_id = 0
|
30 |
+
|
31 |
+
# def on_change_random_input():
|
32 |
+
# st.session_state.image_id = st.session_state.image_id
|
33 |
+
|
34 |
+
# ----------------------------- INPUT ----------------------------------
|
35 |
+
st.header('Input')
|
36 |
+
input_col_1, input_col_2, input_col_3 = st.columns(3)
|
37 |
+
# --------------------------- INPUT column 1 ---------------------------
|
38 |
+
with input_col_1:
|
39 |
+
with st.form('image_form'):
|
40 |
+
|
41 |
+
# image_id = st.number_input('Image ID: ', format='%d', step=1)
|
42 |
+
st.write('**Choose or generate a random image**')
|
43 |
+
chosen_image_id_input = st.empty()
|
44 |
+
image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
45 |
+
|
46 |
+
choose_image_button = st.form_submit_button('Choose the defined image')
|
47 |
+
random_id = st.form_submit_button('Generate a random image')
|
48 |
+
|
49 |
+
if random_id:
|
50 |
+
image_id = random.randint(0, 50000)
|
51 |
+
st.session_state.image_id = image_id
|
52 |
+
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
53 |
+
|
54 |
+
if choose_image_button:
|
55 |
+
image_id = int(image_id)
|
56 |
+
st.session_state.image_id = int(image_id)
|
57 |
+
# st.write(image_id, st.session_state.image_id)
|
58 |
+
|
59 |
+
# ---------------------------- SET UP OUTPUT ------------------------------
|
60 |
+
epsilon_container = st.empty()
|
61 |
+
st.header('Output')
|
62 |
+
st.subheader('Perform attack')
|
63 |
+
|
64 |
+
# perform attack container
|
65 |
+
header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
|
66 |
+
output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
|
67 |
+
|
68 |
+
# prediction error container
|
69 |
+
error_container = st.empty()
|
70 |
+
smoothgrad_header_container = st.empty()
|
71 |
+
|
72 |
+
# smoothgrad container
|
73 |
+
smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
|
74 |
+
smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
|
75 |
+
|
76 |
+
original_image_dict = load_image(st.session_state.image_id)
|
77 |
+
input_image = original_image_dict['image']
|
78 |
+
input_label = original_image_dict['label']
|
79 |
+
input_id = original_image_dict['id']
|
80 |
+
|
81 |
+
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
82 |
+
with output_col_1:
|
83 |
+
pred_prob, pred_class_id, pred_class_label = feed_forward(input_image)
|
84 |
+
# st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
|
85 |
+
st.image(input_image)
|
86 |
+
header_col_1.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.1f}% confidence')
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
if pred_class_id != (input_id-1):
|
91 |
+
with error_container.container():
|
92 |
+
st.write(f'Predicted output: Class ID {pred_class_id} - {pred_class_label} {pred_prob*100:.1f}% confidence')
|
93 |
+
st.error('ConvNeXt misclassified the chosen image. Please choose or generate another image.',
|
94 |
+
icon = "π«")
|
95 |
+
|
96 |
+
# ----------------------------- INPUT column 2 & 3 ----------------------------
|
97 |
+
with input_col_2:
|
98 |
+
with st.form('epsilon_form'):
|
99 |
+
st.write('**Set epsilon or find the smallest epsilon automatically**')
|
100 |
+
chosen_epsilon_input = st.empty()
|
101 |
+
epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=0.001, format='%.3f', step=0.001)
|
102 |
+
|
103 |
+
epsilon_button = st.form_submit_button('Choose the defined epsilon')
|
104 |
+
find_epsilon = st.form_submit_button('Find the smallest epsilon automatically')
|
105 |
+
|
106 |
+
|
107 |
+
with input_col_3:
|
108 |
+
with st.form('iterate_epsilon_form'):
|
109 |
+
max_epsilon = st.number_input('Maximum value of epsilon (Optional setting)', value=0.500, format='%.3f')
|
110 |
+
step_epsilon = st.number_input('Step (Optional setting)', value=0.001, format='%.3f')
|
111 |
+
setting_button = st.form_submit_button('Set iterating mode')
|
112 |
+
|
113 |
+
|
114 |
+
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
115 |
+
if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
|
116 |
+
with output_col_3:
|
117 |
+
if epsilon_button:
|
118 |
+
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
|
119 |
+
else:
|
120 |
+
epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
|
121 |
+
epsilon_container.progress(0, text='Checking epsilon')
|
122 |
+
|
123 |
+
for i, e in enumerate(epsilons):
|
124 |
+
print(e)
|
125 |
+
|
126 |
+
perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
|
127 |
+
epsilon_container.progress(i/len(epsilons), text=f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%')
|
128 |
+
epsilon = e
|
129 |
+
|
130 |
+
if new_id != input_id - 1:
|
131 |
+
epsilon_container.empty()
|
132 |
+
st.balloons()
|
133 |
+
break
|
134 |
+
if i == len(epsilons)-1:
|
135 |
+
epsilon_container.error(f'FSGM failed to attack on this image at epsilon={e:.3f}. Set higher maximum value of epsilon or choose another image',
|
136 |
+
icon = "π«")
|
137 |
+
|
138 |
+
perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
|
139 |
+
perturbed_amount = perturbed_image - input_image
|
140 |
+
header_col_3.write(f'Pertubed amount - epsilon={epsilon:.3f}')
|
141 |
+
st.image(ShowImage(perturbed_amount))
|
142 |
+
|
143 |
+
with output_col_2:
|
144 |
+
# st.write('plus sign')
|
145 |
+
st.image(LoadImage('frontend/images/plus-sign.png'))
|
146 |
+
|
147 |
+
with output_col_4:
|
148 |
+
# st.write('equal sign')
|
149 |
+
st.image(LoadImage('frontend/images/equal-sign.png'))
|
150 |
+
|
151 |
+
# ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
|
152 |
+
with output_col_5:
|
153 |
+
# st.write(f'ID {new_id+1} - {new_label}: {new_prob*100:.3f}% confidence')
|
154 |
+
st.image(ShowImage(perturbed_image))
|
155 |
+
header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')
|
156 |
+
|
157 |
+
# -------------------------- DISPLAY SMOOTHGRAD ---------------------------
|
158 |
+
smoothgrad_header_container.subheader('SmoothGrad visualization')
|
159 |
+
|
160 |
+
with smoothgrad_col_1:
|
161 |
+
smooth_head_1.write(f'SmoothGrad before attacked')
|
162 |
+
heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
|
163 |
+
st.image(heatmap_image)
|
164 |
+
st.image(masked_image)
|
165 |
+
with smoothgrad_col_3:
|
166 |
+
smooth_head_3.write('SmoothGrad after attacked')
|
167 |
+
heatmap_image_attacked, masked_image_attacked, attacked_mask= generate_images(st.session_state.image_id, epsilon=epsilon)
|
168 |
+
st.image(heatmap_image_attacked)
|
169 |
+
st.image(masked_image_attacked)
|
170 |
+
|
171 |
+
with smoothgrad_col_2:
|
172 |
+
st.image(LoadImage('frontend/images/minus-sign-5.png'))
|
173 |
+
|
174 |
+
with smoothgrad_col_5:
|
175 |
+
smooth_head_5.write('SmoothGrad difference')
|
176 |
+
difference_mask = abs(attacked_mask-mask)
|
177 |
+
st.image(ShowHeatMap(difference_mask))
|
178 |
+
masked_image = ShowMaskedImage(difference_mask, perturbed_image)
|
179 |
+
st.image(masked_image)
|
180 |
+
|
181 |
+
with smoothgrad_col_4:
|
182 |
+
st.image(LoadImage('frontend/images/equal-sign.png'))
|
183 |
+
|
184 |
+
|
pages/{3_ImageNet1k.py β 4_ImageNet1k.py}
RENAMED
File without changes
|
requirements.txt
CHANGED
@@ -1,18 +1,17 @@
|
|
1 |
captum==0.5.0
|
2 |
-
deta==1.1.0
|
3 |
graphviz==0.20.1
|
4 |
Markdown==3.4.1
|
5 |
matplotlib==3.6.2
|
6 |
numpy==1.22.3
|
7 |
opencv_python_headless==4.6.0.66
|
8 |
pandas==1.5.2
|
9 |
-
Pillow==9.
|
10 |
plotly==5.11.0
|
11 |
-
scipy==1.
|
12 |
setuptools==65.5.0
|
13 |
-
|
14 |
-
streamlit==1.10.0
|
15 |
torch==1.10.1
|
16 |
torchvision==0.11.2
|
17 |
tqdm==4.64.1
|
18 |
transformers==4.25.1
|
|
|
|
1 |
captum==0.5.0
|
|
|
2 |
graphviz==0.20.1
|
3 |
Markdown==3.4.1
|
4 |
matplotlib==3.6.2
|
5 |
numpy==1.22.3
|
6 |
opencv_python_headless==4.6.0.66
|
7 |
pandas==1.5.2
|
8 |
+
Pillow==9.4.0
|
9 |
plotly==5.11.0
|
10 |
+
scipy==1.10.1
|
11 |
setuptools==65.5.0
|
12 |
+
streamlit==1.19.0
|
|
|
13 |
torch==1.10.1
|
14 |
torchvision==0.11.2
|
15 |
tqdm==4.64.1
|
16 |
transformers==4.25.1
|
17 |
+
git+https://github.com/vlue-c/Visual-Explanation-Methods-PyTorch.git
|