|
import streamlit as st |
|
import altair as alt |
|
import torch |
|
from transformers import AlbertTokenizer, AlbertForSequenceClassification |
|
import sentencepiece as spm |
|
import pandas as pd |
|
|
|
|
|
model_name = "albert-base-v2"


@st.cache_resource
def _load_model(name):
    """Load the ALBERT tokenizer and sequence-classification model once.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would rebuild the model on each rerun.
    ``st.cache_resource`` keeps a single shared instance per ``name`` across
    reruns and sessions.

    Args:
        name: Hugging Face model identifier or local path.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AlbertTokenizer.from_pretrained(name)
    clf = AlbertForSequenceClassification.from_pretrained(name)
    clf.eval()  # inference only: make sure dropout is disabled
    return tok, clf


# NOTE(review): "albert-base-v2" appears to be a base (not sentiment
# fine-tuned) checkpoint; if so, transformers initializes the classification
# head randomly and the Negative/Positive scores are not meaningful. Confirm
# and swap in a fine-tuned checkpoint if real predictions are expected.
tokenizer, model = _load_model(model_name)
|
|
|
|
|
def classify_text(text):
    """Return class probabilities for ``text`` from the module-level model.

    Args:
        text: Input string. It is tokenized as one sequence and truncated to
            the tokenizer's maximum length.

    Returns:
        A list of Python floats, one per output label of ``model`` in
        label-index order, summing to (approximately) 1.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    # Softmax directly on the tensor; no numpy round-trip is needed.
    return torch.softmax(logits, dim=-1).tolist()
|
|
|
|
|
# Heading rendered at the top of the app page.
APP_TITLE = "ALBERT Text Classification App"
st.title(APP_TITLE)
|
|
|
|
|
# Sample input pre-filled into the text area: one Streamlit component
# description per line, joined with newlines (no trailing newline).
default_text = "\n".join((
    "Streamlit-Altair: A component that allows the creation of Altair visualizations within Streamlit.",
    "Streamlit-Bokeh: A component that allows the creation of Bokeh visualizations within Streamlit.",
    "Streamlit-Plotly: A component that allows the creation of Plotly visualizations within Streamlit.",
    "Streamlit-Mapbox: A component that allows the creation of Mapbox maps within Streamlit.",
    "Streamlit-DeckGL: A component that allows the creation of Deck.GL visualizations within Streamlit.",
    "Streamlit-Wordcloud: A component that allows the creation of word clouds within Streamlit.",
    "Streamlit-Audio: A component that allows the playing of audio files within Streamlit.",
    "Streamlit-Video: A component that allows the playing of video files within Streamlit.",
    "Streamlit-EmbedCode: A component that allows the embedding of code snippets within Streamlit.",
    "Streamlit-Components: A component that provides a library of custom Streamlit components created by the Streamlit community.",
))
|
text_input = st.text_area("Enter text to classify", default_text, height=200)

# Streamlit reruns this script on every interaction; st.button returns True
# only on the rerun triggered by its click.
if st.button("Classify"):

    # Whitespace-only input has nothing to classify; treat it like empty text.
    if text_input and text_input.strip():

        probabilities = classify_text(text_input)

        # Assumes a two-label model (index 0 = negative, 1 = positive) —
        # TODO confirm against the checkpoint's label mapping.
        labels = ['Negative', 'Positive']

        df = pd.DataFrame({
            'Label': labels,
            'Probability': probabilities,
        })

        # Pin the x-axis to [0, 1] so bar lengths read as absolute
        # probabilities instead of being rescaled to the largest value.
        chart = alt.Chart(df).mark_bar().encode(
            x=alt.X('Probability', scale=alt.Scale(domain=[0, 1])),
            y=alt.Y('Label', sort=labels),
        )

        st.altair_chart(chart)

    else:

        st.write("Please enter some text to classify.")
|
|