File size: 2,375 Bytes
3a10454
 
 
 
c08ef85
24be11a
3a10454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24be11a
 
 
3a10454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import streamlit as st
import altair as alt
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification
import sentencepiece as spm
import pandas as pd

# Load pre-trained model and tokenizer.
# NOTE(review): "albert-base-v2" is a base checkpoint without a fine-tuned
# sequence-classification head; transformers initializes that head randomly
# (and logs a warning saying so). The Negative/Positive scores shown by this
# app are therefore not meaningful until a checkpoint fine-tuned for the
# target task is used — confirm the intended model.
model_name = "albert-base-v2"


@st.cache_resource
def _load_model(name):
    """Load and cache the ALBERT tokenizer and sequence-classification model.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, every rerun would reload the weights and re-initialize
    a fresh random classification head, so identical text could receive
    different scores on successive clicks. ``st.cache_resource`` keeps one
    shared instance per ``name`` across reruns and sessions.

    Args:
        name: Hugging Face model identifier or local path.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AlbertTokenizer.from_pretrained(name),
        AlbertForSequenceClassification.from_pretrained(name),
    )


tokenizer, model = _load_model(model_name)

# Define function to classify input text
def classify_text(text, *, text_tokenizer=None, classifier=None):
    """Return the class probabilities the model assigns to ``text``.

    Args:
        text: Input string to classify.
        text_tokenizer: Optional tokenizer to use instead of the module-level
            ``tokenizer``. It is called like a Hugging Face tokenizer and must
            return a mapping of model keyword arguments.
        classifier: Optional model to use instead of the module-level
            ``model``. Its output must expose a ``logits`` tensor.

    Returns:
        A list of floats, one per output label: the softmax of the logits for
        the single input sequence.
    """
    # Resolve the module-level defaults at call time so the function can also
    # be driven with other (e.g. fine-tuned or stub) models.
    text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
    classifier = model if classifier is None else classifier

    inputs = text_tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = classifier(**inputs).logits[0]
    # Softmax directly on the tensor; no NumPy round-trip is needed.
    return torch.softmax(logits, dim=-1).tolist()

# Set up Streamlit app
st.title("ALBERT Text Classification App")

# Create input box for user to enter text
default_text = "Streamlit-Altair: A component that allows the creation of Altair visualizations within Streamlit.\nStreamlit-Bokeh: A component that allows the creation of Bokeh visualizations within Streamlit.\nStreamlit-Plotly: A component that allows the creation of Plotly visualizations within Streamlit.\nStreamlit-Mapbox: A component that allows the creation of Mapbox maps within Streamlit.\nStreamlit-DeckGL: A component that allows the creation of Deck.GL visualizations within Streamlit.\nStreamlit-Wordcloud: A component that allows the creation of word clouds within Streamlit.\nStreamlit-Audio: A component that allows the playing of audio files within Streamlit.\nStreamlit-Video: A component that allows the playing of video files within Streamlit.\nStreamlit-EmbedCode: A component that allows the embedding of code snippets within Streamlit.\nStreamlit-Components: A component that provides a library of custom Streamlit components created by the Streamlit community."
text_input = st.text_area("Enter text to classify", default_text, height=200)

# Display names for the model outputs, in output-index order. The length must
# match the probability list returned by classify_text.
# NOTE(review): index 0 -> Negative, 1 -> Positive is an assumption about the
# checkpoint's label order — confirm against model.config.id2label.
LABELS = ['Negative', 'Positive']

# Classify input text and display results
if st.button("Classify"):
    # Treat whitespace-only input as empty rather than classifying it.
    if text_input.strip():
        probabilities = classify_text(text_input)
        df = pd.DataFrame({'Label': LABELS, 'Probability': probabilities})
        chart = alt.Chart(df).mark_bar().encode(
            # Fix the axis to [0, 1] so bar lengths are comparable across runs.
            x=alt.X('Probability', scale=alt.Scale(domain=[0, 1])),
            y=alt.Y('Label', sort=LABELS),
        )
        st.altair_chart(chart)
    else:
        st.write("Please enter some text to classify.")