daniel-de-leon
commited on
Commit
•
5f60a0a
1
Parent(s):
675a95c
Add inference and explanation run times
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
|
|
4 |
pipeline)
|
5 |
import shap
|
6 |
from PIL import Image
|
|
|
7 |
|
8 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
9 |
output_width = 800
|
@@ -33,16 +34,26 @@ tokenizer, model = load_model(model_name)
|
|
33 |
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
|
34 |
explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
|
35 |
|
36 |
-
col1, col2 = st.columns(
|
37 |
text = col1.text_area("Enter text input", value = "Classify me.")
|
38 |
|
|
|
39 |
result = pred(text)
|
|
|
|
|
|
|
|
|
|
|
40 |
top_pred = result[0][0]['label']
|
41 |
col2.write('')
|
42 |
for label in result[0]:
|
43 |
col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
|
44 |
|
45 |
shap_values = explainer([text])
|
|
|
|
|
|
|
|
|
46 |
|
47 |
force_plot = shap.plots.text(shap_values, display=False)
|
48 |
bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
|
@@ -58,4 +69,4 @@ st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Pr
|
|
58 |
st.pyplot(bar_plot, clear_figure=True)
|
59 |
|
60 |
st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
|
61 |
-
components.html(force_plot, height=output_height, width=output_width, scrolling=True)
|
|
|
4 |
pipeline)
|
5 |
import shap
|
6 |
from PIL import Image
|
7 |
+
import time
|
8 |
|
9 |
st.set_option('deprecation.showPyplotGlobalUse', False)
|
10 |
output_width = 800
|
|
|
34 |
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
|
35 |
explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
|
36 |
|
37 |
+
col1, col2, col3 = st.columns(3)
|
38 |
text = col1.text_area("Enter text input", value = "Classify me.")
|
39 |
|
40 |
+
start_time = time.time()
|
41 |
result = pred(text)
|
42 |
+
inference_time = time.time() - start_time
|
43 |
+
|
44 |
+
col3.write('')
|
45 |
+
col3.write(f'**Inference Time:** {inference_time: .4f}')
|
46 |
+
|
47 |
top_pred = result[0][0]['label']
|
48 |
col2.write('')
|
49 |
for label in result[0]:
|
50 |
col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
|
51 |
|
52 |
shap_values = explainer([text])
|
53 |
+
explanation_time = shap_values.compute_time
|
54 |
+
|
55 |
+
col3.write('')
|
56 |
+
col3.write(f'**Explanation Time:** {explanation_time: .4f}')
|
57 |
|
58 |
force_plot = shap.plots.text(shap_values, display=False)
|
59 |
bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
|
|
|
69 |
st.pyplot(bar_plot, clear_figure=True)
|
70 |
|
71 |
st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
|
72 |
+
components.html(force_plot, height=output_height, width=output_width, scrolling=True)
|