daniel-de-leon commited on
Commit
5f60a0a
1 Parent(s): 675a95c

Add inference and explanation run times

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
  pipeline)
5
  import shap
6
  from PIL import Image
 
7
 
8
  st.set_option('deprecation.showPyplotGlobalUse', False)
9
  output_width = 800
@@ -33,16 +34,26 @@ tokenizer, model = load_model(model_name)
33
  pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
34
  explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
35
 
36
- col1, col2 = st.columns(2)
37
  text = col1.text_area("Enter text input", value = "Classify me.")
38
 
 
39
  result = pred(text)
 
 
 
 
 
40
  top_pred = result[0][0]['label']
41
  col2.write('')
42
  for label in result[0]:
43
  col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
44
 
45
  shap_values = explainer([text])
 
 
 
 
46
 
47
  force_plot = shap.plots.text(shap_values, display=False)
48
  bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
@@ -58,4 +69,4 @@ st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Pr
58
  st.pyplot(bar_plot, clear_figure=True)
59
 
60
  st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
61
- components.html(force_plot, height=output_height, width=output_width, scrolling=True)
 
4
  pipeline)
5
  import shap
6
  from PIL import Image
7
+ import time
8
 
9
  st.set_option('deprecation.showPyplotGlobalUse', False)
10
  output_width = 800
 
34
  pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
35
  explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
36
 
37
+ col1, col2, col3 = st.columns(3)
38
  text = col1.text_area("Enter text input", value = "Classify me.")
39
 
40
+ start_time = time.time()
41
  result = pred(text)
42
+ inference_time = time.time() - start_time
43
+
44
+ col3.write('')
45
+ col3.write(f'**Inference Time:** {inference_time: .4f}')
46
+
47
  top_pred = result[0][0]['label']
48
  col2.write('')
49
  for label in result[0]:
50
  col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
51
 
52
  shap_values = explainer([text])
53
+ explanation_time = shap_values.compute_time
54
+
55
+ col3.write('')
56
+ col3.write(f'**Explanation Time:** {explanation_time: .4f}')
57
 
58
  force_plot = shap.plots.text(shap_values, display=False)
59
  bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
 
69
  st.pyplot(bar_plot, clear_figure=True)
70
 
71
  st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
72
+ components.html(force_plot, height=output_height, width=output_width, scrolling=True)