saattrupdan
commited on
Commit
•
316b0d8
1
Parent(s):
90ed2ce
feat: Remove SHAP
Browse files- app.py +3 -39
- requirements.txt +0 -1
app.py
CHANGED
@@ -5,8 +5,6 @@ from numba.core.errors import NumbaDeprecationWarning
|
|
5 |
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
|
6 |
import gradio as gr
|
7 |
from transformers import pipeline
|
8 |
-
from shap import Explainer
|
9 |
-
import numpy as np
|
10 |
from typing import Tuple, Dict, List
|
11 |
|
12 |
|
@@ -30,47 +28,13 @@ def main():
|
|
30 |
|
31 |
def classification(text) -> Tuple[Dict[str, float], dict]:
|
32 |
output: List[dict] = pipe(text)[0]
|
33 |
-
print(output)
|
|
|
34 |
|
35 |
-
explainer = Explainer(pipe)
|
36 |
-
explanation = explainer([text])
|
37 |
-
shap_values = explanation.values[0].sum(axis=1)
|
38 |
-
|
39 |
-
# Find the SHAP boundary
|
40 |
-
boundary = 0.03
|
41 |
-
if np.abs(shap_values).max() <= boundary:
|
42 |
-
boundary = np.abs(shap_values).max() - 1e-6
|
43 |
-
|
44 |
-
words: List[str] = explanation.data[0]
|
45 |
-
records = list()
|
46 |
-
char_idx = 0
|
47 |
-
for word, shap_value in zip(words, shap_values):
|
48 |
-
|
49 |
-
if abs(shap_value) <= boundary:
|
50 |
-
entity = 'O'
|
51 |
-
else:
|
52 |
-
entity = output['label'].lower().replace(' ', '-')
|
53 |
-
|
54 |
-
if len(word):
|
55 |
-
start = char_idx
|
56 |
-
char_idx += len(word)
|
57 |
-
end = char_idx
|
58 |
-
records.append(dict(
|
59 |
-
entity=entity,
|
60 |
-
word=word,
|
61 |
-
score=abs(shap_value),
|
62 |
-
start=start,
|
63 |
-
end=end,
|
64 |
-
))
|
65 |
-
print(records)
|
66 |
-
|
67 |
-
return ({output["label"]: output["score"]}, dict(text=text, entities=records))
|
68 |
-
|
69 |
-
color_map = {"offensive": "red", "not-offensive": "green", 'O': 'white'}
|
70 |
demo = gr.Interface(
|
71 |
fn=classification,
|
72 |
inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
|
73 |
-
outputs=
|
74 |
examples=examples,
|
75 |
title="Danish Offensive Text Detection",
|
76 |
description="""
|
|
|
5 |
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
|
6 |
import gradio as gr
|
7 |
from transformers import pipeline
|
|
|
|
|
8 |
from typing import Tuple, Dict, List
|
9 |
|
10 |
|
|
|
28 |
|
29 |
def classification(text) -> Tuple[Dict[str, float], dict]:
|
30 |
output: List[dict] = pipe(text)[0]
|
31 |
+
print(text, output)
|
32 |
+
return {output["label"]: output["score"]}
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
demo = gr.Interface(
|
35 |
fn=classification,
|
36 |
inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
|
37 |
+
outputs=gr.Label(),
|
38 |
examples=examples,
|
39 |
title="Danish Offensive Text Detection",
|
40 |
description="""
|
requirements.txt
CHANGED
@@ -88,7 +88,6 @@ rfc3986==1.5.0
|
|
88 |
scikit-learn==1.2.2
|
89 |
scipy==1.10.1
|
90 |
semantic-version==2.10.0
|
91 |
-
shap==0.41.0
|
92 |
six==1.16.0
|
93 |
slicer==0.0.7
|
94 |
sniffio==1.3.0
|
|
|
88 |
scikit-learn==1.2.2
|
89 |
scipy==1.10.1
|
90 |
semantic-version==2.10.0
|
|
|
91 |
six==1.16.0
|
92 |
slicer==0.0.7
|
93 |
sniffio==1.3.0
|