# Milestone3 / app.py
# Author: Jainesh212
# Last commit: de71ace ("Update app.py")
# (Copied from the hosting site's file view: raw | history | blame | 2.28 kB)
import streamlit as st
import transformers
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
# Load the pre-trained BERT checkpoint and wrap it in a PyTorch
# text-classification pipeline shared by the functions below.
# NOTE(review): the checkpoint name contains "sentiment", so its labels are
# presumably sentiment ratings rather than toxicity classes - confirm this is
# the intended model for a toxicity app.
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# presumably re-instantiates the model on every rerun; consider a Streamlit
# resource cache if the installed version provides one.
model_name = 'nlptown/bert-base-multilingual-uncased-sentiment'
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)
pipeline = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    framework='pt',
    task='text-classification',
)
# Define the toxicity classification function
def classify_toxicity(text):
    """Run *text* through the module-level ``pipeline`` and return its top result.

    Returns:
        A ``(label, score)`` tuple taken from the first prediction the
        pipeline returns for *text*. The label set is whatever the loaded
        checkpoint defines (see ``model_name``).
    """
    predictions = pipeline(text)
    best = predictions[0]
    return best['label'], best['score']
# Define the Streamlit app
def app():
    """Render the toxicity classification page.

    Shows a text form. On submit with non-blank text, the function:

    - classifies the text with ``classify_toxicity`` and displays the result;
    - appends the result to a DataFrame kept in ``st.session_state.results``
      (so it survives Streamlit reruns);
    - renders the accumulated table and a bar chart of result counts per label.
    """
    # Set page title and favicon first: Streamlit requires set_page_config
    # to precede other Streamlit calls on the page.
    st.set_page_config(page_title='Toxicity Classification App', page_icon=':guardsman:')
    # Create a persistent DataFrame that survives reruns of the script.
    if 'results' not in st.session_state:
        st.session_state.results = pd.DataFrame(columns=['text', 'toxicity', 'score'])
    # Set app header
    st.write('# Toxicity Classification App')
    st.write('Enter some text and the app will classify its toxicity.')
    # Create a form for users to enter their text
    with st.form(key='text_form'):
        text_input = st.text_input(label='Enter your text:')
        submit_button = st.form_submit_button(label='Classify')
    # Classify the text and display the results; blank/whitespace-only
    # input is ignored rather than sent to the model.
    if submit_button and text_input.strip():
        label, score = classify_toxicity(text_input)
        st.write('## Classification Result')
        st.write(f'**Text:** {text_input}')
        st.write(f'**Toxicity:** {label}')
        st.write(f'**Score:** {score:.2f}')
        # Add the classification result to the persistent DataFrame.
        st.session_state.results = _append_result(
            st.session_state.results, text_input, label, score
        )
    # Display the persistent DataFrame
    st.write('## Classification Results')
    st.write(st.session_state.results)
    # Display a chart of how many stored results fall under each label.
    chart_data = st.session_state.results.groupby('toxicity').size().reset_index(name='count')
    st.bar_chart(chart_data.set_index('toxicity'))


def _append_result(results, text, label, score):
    """Return a copy of *results* with one (text, toxicity, score) row appended.

    Replaces ``DataFrame.append``, which was removed in pandas 2.0.
    """
    new_row = pd.DataFrame([{'text': text, 'toxicity': label, 'score': score}])
    if results.empty:
        # Skip concatenating onto an empty frame: newer pandas warns about
        # dtype inference from empty entries in concat.
        return new_row
    return pd.concat([results, new_row], ignore_index=True)
# Script entry point: render the page when this file is executed as the
# main module.
if __name__ == '__main__':
    app()