"""Streamlit app that scores user-entered text against six toxicity categories.

Each click on "Classify" runs the text through a BERT sequence classifier,
shows the per-category scores, appends them to a per-session results table
and plots the score distribution for selected categories.
"""
import streamlit as st
import pandas as pd
import transformers
import torch
import seaborn as sns
import matplotlib.pyplot as plt

# Display names for the model outputs: score i is shown under LABELS[i].
LABELS = ['Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate']
# Categories whose score distribution is plotted below the results table.
PLOTTED_LABELS = ['Toxic', 'Severe Toxic']
MODEL_NAME = 'bert-base-uncased'

# Set up the Streamlit app. set_page_config has to run before any other
# Streamlit command, including the st.error() used to report a failed model
# load, so it is issued first.
st.set_page_config(layout="wide")


@st.cache_resource
def load_model():
    """Load the pre-trained BERT tokenizer and classifier.

    Cached by Streamlit so the weights are loaded once per server process
    instead of on every script rerun (every widget interaction reruns the
    whole script).

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in evaluation mode.
    """
    # NOTE(review): 'bert-base-uncased' is the base checkpoint, so the 6-way
    # classification head is newly initialized rather than trained for
    # toxicity. Confirm whether a fine-tuned checkpoint was intended.
    tokenizer = transformers.BertTokenizer.from_pretrained(MODEL_NAME)
    model = transformers.BertForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=len(LABELS)
    )
    model.eval()
    return tokenizer, model


def classify(text, tokenizer, model):
    """Return one score per entry of ``LABELS`` for ``text``.

    Args:
        text: The raw text to classify.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Classifier returned by ``load_model``.

    Returns:
        A list of ``len(LABELS)`` floats: the softmax of the model logits.
    """
    # Tokenize the text and convert to input IDs
    encoded_text = tokenizer.encode_plus(
        text,
        max_length=512,
        padding='max_length',
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        output = model(
            input_ids=encoded_text['input_ids'],
            attention_mask=encoded_text['attention_mask'],
        )
    # NOTE(review): softmax makes the six scores sum to 1. If the categories
    # are meant to be independent (multi-label), sigmoid would be the usual
    # choice. Kept as-is; confirm the intended semantics.
    return torch.nn.functional.softmax(output.logits, dim=1)[0].tolist()


# Load the pre-trained BERT model and tokenizer
try:
    tokenizer, model = load_model()
except Exception as e:
    st.error(f"Error loading the model: {e}")
    # Nothing below can work without the model; end this script run here
    # instead of failing later with a NameError.
    st.stop()

st.title('Toxicity Classification App')

# Create a text input for the user to enter their text
text_input = st.text_input('Enter text to classify')

# Create a button to run the classification
if st.button('Classify'):
    if not text_input:
        st.warning("Please enter text to classify.")
    else:
        probabilities = classify(text_input, tokenizer, model)

        # Display the classification results
        for label, probability in zip(LABELS, probabilities):
            st.write(f'{label}:', probability)

        # One-row DataFrame for this classification. DataFrame.append was
        # removed in pandas 2.0, so the row is built directly.
        results_df = pd.DataFrame(
            [{'Text': text_input, **dict(zip(LABELS, probabilities))}],
            columns=['Text', *LABELS],
        )

        # Append the classification results to the persistent DataFrame.
        # An empty/missing table is replaced outright so pd.concat never
        # sees an empty frame.
        previous_results = st.session_state.get('results')
        if previous_results is None or previous_results.empty:
            st.session_state['results'] = results_df
        else:
            st.session_state['results'] = pd.concat(
                [previous_results, results_df], ignore_index=True
            )

# Display the persistent DataFrame
all_results = st.session_state.get('results', pd.DataFrame())
st.write('Classification Results:', all_results)

# Plot the distribution of probabilities for each category
if len(all_results) > 0:
    for label in PLOTTED_LABELS:
        # st.pyplot expects a Figure (histplot returns an Axes), and each
        # category needs its own axes so the plots do not overlap.
        fig, ax = plt.subplots()
        sns.histplot(data=all_results, x=label, kde=True, ax=ax)
        st.pyplot(fig)
        # Release the figure; otherwise one accumulates per rerun.
        plt.close(fig)