import streamlit as st import streamlit.components.v1 as components from transformers import (AutoModelForSequenceClassification, AutoTokenizer, pipeline) import shap from PIL import Image st.set_option('deprecation.showPyplotGlobalUse', False) output_width = 800 output_height = 300 rescale_logits = False st.set_page_config(page_title='Text Classification with Shap') st.title('Interpreting HF Pipeline Text Classification with Shap') form = st.sidebar.form("Model Selection") form.header('Model Selection') model_name = form.text_input("Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain") form.form_submit_button("Submit") @st.cache_data() def load_model(model_name): tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) model = AutoModelForSequenceClassification.from_pretrained(model_name) return tokenizer, model tokenizer, model = load_model(model_name) pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None) explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits) col1, col2 = st.columns(2) text = col1.text_area("Enter text input", value = "Classify me.") result = pred(text) top_pred = result[0][0]['label'] col2.write('') for label in result[0]: col2.write(f'**{label["label"]}**: {label["score"]: .2f}') shap_values = explainer([text]) force_plot = shap.plots.text(shap_values, display=False) bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False) st.markdown(""" """, unsafe_allow_html=True) st.markdown(f'
Shap Bar Plot for {top_pred} Prediction
Shap Interactive Force Plot