Milestone3 / app.py
Jainesh212's picture
Update app.py
d3ae133 verified
raw
history blame
2.97 kB
import streamlit as st
import pandas as pd
import transformers
import torch
import seaborn as sns
import matplotlib.pyplot as plt
# Hugging Face checkpoint used for classification.
# NOTE(review): 'bert-base-uncased' ships without a trained sequence-
# classification head, so with num_labels=6 transformers initializes that
# head randomly and the scores are not meaningful until a fine-tuned
# toxicity checkpoint is substituted here. TODO: point at a fine-tuned model.
MODEL_NAME = 'bert-base-uncased'
NUM_LABELS = 6


@st.cache_resource
def load_model(model_name=MODEL_NAME, num_labels=NUM_LABELS):
    """Load and return ``(tokenizer, model)`` for *model_name*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded weights across those reruns instead
    of reloading BERT from disk each time. A call that raises is not cached,
    so a transient load failure is retried on the next rerun.
    """
    tok = transformers.BertTokenizer.from_pretrained(model_name)
    mdl = transformers.BertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
    return tok, mdl


# Set up the Streamlit app. set_page_config must be the first Streamlit
# command of the run, so it precedes any st.error emitted while loading.
st.set_page_config(layout="wide")
st.title('Toxicity Classification App')

# Load the pre-trained BERT model and tokenizer
try:
    tokenizer, model = load_model()
except Exception as e:  # UI boundary: surface any load failure to the user
    st.error(f"Error loading the model: {e}")
    # Without a model, the classification code below would fail with a
    # NameError on `tokenizer`/`model`; end this run cleanly instead.
    st.stop()
# Create a text input for the user to enter their text
text_input = st.text_input('Enter text to classify')

# Category names; index i of the model's output is reported under
# TOXICITY_LABELS[i].
TOXICITY_LABELS = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']


def _score_text(text):
    """Return ``{label: score}`` for *text* using the module-level model.

    The categories are not mutually exclusive (a comment can be both obscene
    and an insult, and "Severe Toxic" overlaps "Toxic"), so each logit goes
    through an independent sigmoid. A softmax would force the six scores to
    sum to 1 and make every category compete with the others.
    """
    encoded = tokenizer.encode_plus(
        text,
        max_length=512,  # BERT's maximum sequence length; longer text is truncated
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    # No padding: a single sequence needs none, and padding to 512 tokens only
    # added masked-out computation to every request.
    with torch.no_grad():
        logits = model(encoded['input_ids'], encoded['attention_mask'])[0]
    scores = torch.sigmoid(logits)[0].tolist()
    return dict(zip(TOXICITY_LABELS, scores))


def _append_result(row):
    """Append one result row to the session-persistent ``results`` DataFrame."""
    row_df = pd.DataFrame([row])
    previous = st.session_state.get('results')
    if previous is None or previous.empty:
        # Start from the row itself rather than concatenating onto an empty,
        # column-only frame, which can leave the score columns as object
        # dtype (and newer pandas warns about empty concat entries).
        st.session_state['results'] = row_df
    else:
        st.session_state['results'] = pd.concat([previous, row_df], ignore_index=True)


# Create a button to run the classification
if st.button('Classify'):
    if not text_input.strip():
        st.warning("Please enter text to classify.")
    else:
        scores = _score_text(text_input)
        # Display the classification results
        for label, score in scores.items():
            st.write(f'{label}:', score)
        # Persist the result across reruns. DataFrame.append was removed in
        # pandas 2.0, so rows are added with pd.concat.
        _append_result({'Text': text_input, **scores})
# Show every classification recorded in this session; before the first
# classification the session has no 'results' entry and an empty frame is shown.
accumulated_results = st.session_state.get('results', pd.DataFrame())
st.write('Classification Results:', accumulated_results)
# Plot the distribution of probabilities for each category
df = st.session_state.get('results', pd.DataFrame())
# Every column except the raw text holds a per-category score.
categories = [column for column in df.columns if column != 'Text']
if len(df) > 0 and categories:
    ncols = min(3, len(categories))
    nrows = -(-len(categories) // ncols)  # ceiling division
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 3.5 * nrows), squeeze=False
    )
    flat_axes = axes.ravel()
    for ax, category in zip(flat_axes, categories):
        # Coerce to numbers: a frame built from an empty, column-only
        # DataFrame may hold the scores as object dtype, which seaborn would
        # otherwise treat as categorical.
        values = pd.to_numeric(df[category], errors='coerce').dropna()
        # A KDE cannot be estimated from zero-variance data (e.g. a single
        # classification), so only request one when the values differ.
        sns.histplot(x=values, kde=values.nunique() > 1, ax=ax)
        ax.set_title(category)
    # Hide grid cells left over when the category count isn't a multiple of ncols.
    for ax in flat_axes[len(categories):]:
        ax.set_visible(False)
    fig.tight_layout()
    # st.pyplot takes a matplotlib Figure, not the Axes sns.histplot returns.
    st.pyplot(fig)
    # Release the figure so figures don't accumulate across Streamlit reruns.
    plt.close(fig)