andreped commited on
Commit
f073f4b
1 Parent(s): 8beaf6e

Developed simple demo for XAI

Browse files
Files changed (5) hide show
  1. .gitignore +10 -0
  2. README.md +13 -1
  3. app.py +99 -0
  4. requirements.txt +4 -0
  5. workflows/deploy.yml +20 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ venv/
2
+ *.jpg
3
+ *.jpeg
4
+ *.png
5
+ flagged/
6
+ *.DS_Store
7
+ *__pycache__/
8
+ *.vs/
9
+ *.idea/
10
+ gradio_cached_examples/
README.md CHANGED
@@ -1 +1,13 @@
1
- # vit-explainer
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: 'vit-explainer'
3
+ colorFrom: indigo
4
+ colorTo: indigo
5
+ sdk: gradio
6
+ app_port: 7860
7
+ emoji: 🔥
8
+ pinned: false
9
+ license: mit
10
+ app_file: app.py
11
+ ---
12
+
13
+ # vit-explainer
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import re
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from torch import topk
7
+ from torch.nn.functional import softmax
8
+ from transformers import ViTImageProcessor, ViTForImageClassification
9
+ from transformers_interpret import ImageClassificationExplainer
10
+
11
+
12
+ def load_label_data():
13
+ file_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
14
+ response = requests.get(file_url)
15
+ labels = []
16
+ pattern = '["\'](.*?)["\']'
17
+ for line in response.text.split('\n'):
18
+ try:
19
+ tmp = re.findall(pattern, line)[0]
20
+ labels.append(tmp)
21
+ except IndexError:
22
+ pass
23
+ return labels
24
+
25
+
26
+ class WebUI:
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.nb_classes = 10
30
+ self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
31
+ self.model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
32
+ self.labels = load_label_data()
33
+
34
+ def run_model(self, image):
35
+ inputs = self.processor(images=image, return_tensors="pt")
36
+ outputs = self.model(**inputs)
37
+ outputs = softmax(outputs.logits, dim=1)
38
+ outputs = topk(outputs, k=self.nb_classes)
39
+ return outputs
40
+
41
+ def classify_image(self, image):
42
+ top10 = self.run_model(image)
43
+ return {self.labels[top10[1][0][i]]: float(top10[0][0][i]) for i in range(self.nb_classes)}
44
+
45
+ def explain_pred(self, image):
46
+ image_classification_explainer = ImageClassificationExplainer(model=self.model, feature_extractor=self.processor)
47
+ saliency = image_classification_explainer(image)
48
+ saliency = np.squeeze(np.moveaxis(saliency, 1, 3))
49
+ saliency[saliency >= 0.05] = 0.05
50
+ saliency[saliency <= -0.05] = -0.05
51
+ return saliency
52
+
53
+ def run(self):
54
+ examples=[
55
+ ['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/cat.jpg'],
56
+ ['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/dog.jpeg'],
57
+ ]
58
+ with gr.Blocks() as demo:
59
+ with gr.Row():
60
+ image = gr.Image(height=512)
61
+ label = gr.Label(num_top_classes=self.nb_classes)
62
+ saliency = gr.Image(height=512, label="saliency map", show_label=True)
63
+
64
+ with gr.Column(scale=0.2, min_width=150):
65
+ run_btn = gr.Button("Run analysis", variant="primary", elem_id="run-button")
66
+
67
+ run_btn.click(
68
+ fn=lambda x: self.explain_pred(x),
69
+ inputs=image,
70
+ outputs=saliency,
71
+ )
72
+
73
+ run_btn.click(
74
+ fn=lambda x: self.classify_image(x),
75
+ inputs=image,
76
+ outputs=label,
77
+ )
78
+
79
+ gr.Examples(
80
+ examples=[
81
+ ['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/cat.jpg'],
82
+ ['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/dog.jpeg'],
83
+ ],
84
+ inputs=image,
85
+ outputs=image,
86
+ fn=lambda x: x,
87
+ cache_examples=True,
88
+ )
89
+
90
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
91
+
92
+
93
+ def main():
94
+ ui = WebUI()
95
+ ui.run()
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ transformers-interpret
workflows/deploy.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Deploy
2
+ on:
3
+ push:
4
+ branches: [ main ]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-hub:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v3
14
+ with:
15
+ fetch-depth: 0
16
+ lfs: true
17
+ - name: Push to hub
18
+ env:
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push https://andreped:[email protected]/spaces/andreped/vit-explainer main