saattrupdan committed on
Commit
316b0d8
1 Parent(s): 90ed2ce

feat: Remove SHAP

Browse files
Files changed (2) hide show
  1. app.py +3 -39
  2. requirements.txt +0 -1
app.py CHANGED
@@ -5,8 +5,6 @@ from numba.core.errors import NumbaDeprecationWarning
5
  warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
6
  import gradio as gr
7
  from transformers import pipeline
8
- from shap import Explainer
9
- import numpy as np
10
  from typing import Tuple, Dict, List
11
 
12
 
@@ -30,47 +28,13 @@ def main():
30
 
31
  def classification(text) -> Tuple[Dict[str, float], dict]:
32
  output: List[dict] = pipe(text)[0]
33
- print(output)
 
34
 
35
- explainer = Explainer(pipe)
36
- explanation = explainer([text])
37
- shap_values = explanation.values[0].sum(axis=1)
38
-
39
- # Find the SHAP boundary
40
- boundary = 0.03
41
- if np.abs(shap_values).max() <= boundary:
42
- boundary = np.abs(shap_values).max() - 1e-6
43
-
44
- words: List[str] = explanation.data[0]
45
- records = list()
46
- char_idx = 0
47
- for word, shap_value in zip(words, shap_values):
48
-
49
- if abs(shap_value) <= boundary:
50
- entity = 'O'
51
- else:
52
- entity = output['label'].lower().replace(' ', '-')
53
-
54
- if len(word):
55
- start = char_idx
56
- char_idx += len(word)
57
- end = char_idx
58
- records.append(dict(
59
- entity=entity,
60
- word=word,
61
- score=abs(shap_value),
62
- start=start,
63
- end=end,
64
- ))
65
- print(records)
66
-
67
- return ({output["label"]: output["score"]}, dict(text=text, entities=records))
68
-
69
- color_map = {"offensive": "red", "not-offensive": "green", 'O': 'white'}
70
  demo = gr.Interface(
71
  fn=classification,
72
  inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
73
- outputs=[gr.Label(), gr.HighlightedText().style(color_map=color_map)],
74
  examples=examples,
75
  title="Danish Offensive Text Detection",
76
  description="""
 
5
  warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
6
  import gradio as gr
7
  from transformers import pipeline
 
 
8
  from typing import Tuple, Dict, List
9
 
10
 
 
28
 
29
  def classification(text) -> Tuple[Dict[str, float], dict]:
30
  output: List[dict] = pipe(text)[0]
31
+ print(text, output)
32
+ return {output["label"]: output["score"]}
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  demo = gr.Interface(
35
  fn=classification,
36
  inputs=gr.Textbox(placeholder="Enter sentence here...", value=examples[0]),
37
+ outputs=gr.Label(),
38
  examples=examples,
39
  title="Danish Offensive Text Detection",
40
  description="""
requirements.txt CHANGED
@@ -88,7 +88,6 @@ rfc3986==1.5.0
88
  scikit-learn==1.2.2
89
  scipy==1.10.1
90
  semantic-version==2.10.0
91
- shap==0.41.0
92
  six==1.16.0
93
  slicer==0.0.7
94
  sniffio==1.3.0
 
88
  scikit-learn==1.2.2
89
  scipy==1.10.1
90
  semantic-version==2.10.0
 
91
  six==1.16.0
92
  slicer==0.0.7
93
  sniffio==1.3.0