Spaces:
Runtime error
Runtime error
import streamlit as st | |
import transformers | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline | |
# Load the pre-trained BERT model.
# NOTE(review): the checkpoint name says "sentiment" -- it does not look like a
# toxicity classifier, so the "toxicity" labels the app shows are presumably
# sentiment classes. Confirm, and swap in a toxicity checkpoint if real
# toxicity scores are wanted.
model_name = 'nlptown/bert-base-multilingual-uncased-sentiment'


@st.cache_resource(show_spinner=False)
def _load_pipeline(name):
    """Load the tokenizer and model for checkpoint *name* and wrap them in a pipeline.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be reloaded from disk on every submit.
    ``st.cache_resource`` keeps a single shared instance per process.
    ``show_spinner=False`` keeps the first (uncached) call from drawing a
    spinner element ahead of ``st.set_page_config()`` in ``app()``.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSequenceClassification.from_pretrained(name)
    return TextClassificationPipeline(model=mdl, tokenizer=tok, framework='pt', task='text-classification')


pipeline = _load_pipeline(model_name)
# Keep the original module-level names available to any other code using them.
tokenizer = pipeline.tokenizer
model = pipeline.model
# Define the toxicity classification function
def classify_toxicity(text):
    """Classify *text* with the module-level ``pipeline``.

    Args:
        text: The string to classify.

    Returns:
        A ``(label, score)`` tuple taken from the first result the pipeline
        returns for *text*.
    """
    # truncation=True: without it, input longer than the model's maximum
    # sequence length makes the forward pass fail instead of being classified.
    result = pipeline(text, truncation=True)[0]
    return result['label'], result['score']
# Define the Streamlit app
def app():
    """Render the page: input form, latest result, result history table and chart.

    The history lives in ``st.session_state.results``, a DataFrame with the
    columns ``text``, ``toxicity`` and ``score``. Its creation is guarded by
    a membership check, so it persists across reruns of the same session.
    """
    # Streamlit requires set_page_config to run before other Streamlit
    # commands (strictly enforced in older releases), so call it first.
    st.set_page_config(page_title='Toxicity Classification App', page_icon=':guardsman:')

    # Create a persistent DataFrame
    if 'results' not in st.session_state:
        st.session_state.results = pd.DataFrame(columns=['text', 'toxicity', 'score'])

    # Set app header
    st.write('# Toxicity Classification App')
    st.write('Enter some text and the app will classify its toxicity.')

    # Create a form for users to enter their text
    with st.form(key='text_form'):
        text_input = st.text_input(label='Enter your text:')
        submit_button = st.form_submit_button(label='Classify')

    # Classify the text and display the results; empty or whitespace-only
    # input is skipped.
    if submit_button and text_input.strip():
        label, score = classify_toxicity(text_input)
        st.write('## Classification Result')
        st.write(f'**Text:** {text_input}')
        st.write(f'**Toxicity:** {label}')
        st.write(f'**Score:** {score:.2f}')

        # Add the classification result to the persistent DataFrame.
        # DataFrame.append was removed in pandas 2.0, so build a one-row frame
        # and concatenate instead.
        new_row = pd.DataFrame([{'text': text_input, 'toxicity': label, 'score': score}])
        history = st.session_state.results
        # Concatenating onto an empty frame is deprecated in recent pandas,
        # so the first result simply becomes the history.
        st.session_state.results = new_row if history.empty else pd.concat([history, new_row], ignore_index=True)

    # Display the persistent DataFrame
    st.write('## Classification Results')
    st.write(st.session_state.results)

    # Display a chart of the classification results; there is nothing to
    # plot until the first successful submit.
    if not st.session_state.results.empty:
        chart_data = st.session_state.results.groupby('toxicity').size().reset_index(name='count')
        st.bar_chart(chart_data.set_index('toxicity'))


if __name__ == '__main__':
    app()