Spaces:
Runtime error
Runtime error
Jainesh212
commited on
Commit
•
86c0799
1
Parent(s):
df1666f
Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,11 @@ import seaborn as sns
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
8 |
# Load the pre-trained BERT model and tokenizer
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
11 |
|
12 |
# Set up the Streamlit app
|
13 |
st.set_page_config(layout="wide")
|
@@ -18,45 +21,49 @@ text_input = st.text_input('Enter text to classify')
|
|
18 |
|
19 |
# Create a button to run the classification
|
20 |
if st.button('Classify'):
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
|
61 |
# Display the persistent DataFrame
|
62 |
st.write('Classification Results:', st.session_state.get('results', pd.DataFrame()))
|
@@ -64,9 +71,5 @@ st.write('Classification Results:', st.session_state.get('results', pd.DataFrame
|
|
64 |
# Plot the distribution of probabilities for each category
|
65 |
if len(st.session_state.get('results', pd.DataFrame())) > 0:
|
66 |
df = st.session_state['results']
|
67 |
-
|
68 |
-
sns.histplot(data=df, x='Toxic', kde=True
|
69 |
-
axes[0].set_title('Toxic Probability Distribution')
|
70 |
-
sns.histplot(data=df, x='Severe Toxic', kde=True, ax=axes[1])
|
71 |
-
axes[1].set_title('Severe Toxic Probability Distribution')
|
72 |
-
st.pyplot(fig)
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
8 |
# Load the pre-trained BERT model and tokenizer
|
9 |
+
try:
|
10 |
+
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
|
11 |
+
model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
|
12 |
+
except Exception as e:
|
13 |
+
st.error(f"Error loading the model: {e}")
|
14 |
|
15 |
# Set up the Streamlit app
|
16 |
st.set_page_config(layout="wide")
|
|
|
21 |
|
22 |
# Create a button to run the classification
|
23 |
if st.button('Classify'):
|
24 |
+
if not text_input:
|
25 |
+
st.warning("Please enter text to classify.")
|
26 |
+
else:
|
27 |
+
# Tokenize the text and convert to input IDs
|
28 |
+
encoded_text = tokenizer.encode_plus(
|
29 |
+
text_input,
|
30 |
+
max_length=512,
|
31 |
+
padding='max_length',
|
32 |
+
truncation=True,
|
33 |
+
add_special_tokens=True,
|
34 |
+
return_attention_mask=True,
|
35 |
+
return_tensors='pt'
|
36 |
+
)
|
37 |
|
38 |
+
# Run the text through the model
|
39 |
+
with torch.no_grad():
|
40 |
+
output = model(encoded_text['input_ids'], encoded_text['attention_mask'])
|
41 |
+
probabilities = torch.nn.functional.softmax(output[0], dim=1).tolist()[0]
|
42 |
|
43 |
+
# Display the classification results
|
44 |
+
st.write('Toxic:', probabilities[0])
|
45 |
+
st.write('Severe Toxic:', probabilities[1])
|
46 |
+
st.write('Obscene:', probabilities[2])
|
47 |
+
st.write('Threat:', probabilities[3])
|
48 |
+
st.write('Insult:', probabilities[4])
|
49 |
+
st.write('Identity Hate:', probabilities[5])
|
50 |
|
51 |
+
# Create a DataFrame to store the classification results
|
52 |
+
results_df = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
|
53 |
+
results_df = results_df.append({
|
54 |
+
'Text': text_input,
|
55 |
+
'Toxic': probabilities[0],
|
56 |
+
'Severe Toxic': probabilities[1],
|
57 |
+
'Obscene': probabilities[2],
|
58 |
+
'Threat': probabilities[3],
|
59 |
+
'Insult': probabilities[4],
|
60 |
+
'Identity Hate': probabilities[5]
|
61 |
+
}, ignore_index=True)
|
62 |
|
63 |
+
# Append the classification results to the persistent DataFrame
|
64 |
+
if 'results' not in st.session_state:
|
65 |
+
st.session_state['results'] = pd.DataFrame(columns=results_df.columns)
|
66 |
+
st.session_state['results'] = st.session_state['results'].append(results_df, ignore_index=True)
|
67 |
|
68 |
# Display the persistent DataFrame
|
69 |
st.write('Classification Results:', st.session_state.get('results', pd.DataFrame()))
|
|
|
71 |
# Plot the distribution of probabilities for each category
|
72 |
if len(st.session_state.get('results', pd.DataFrame())) > 0:
|
73 |
df = st.session_state['results']
|
74 |
+
st.pyplot(sns.histplot(data=df, x='Toxic', kde=True))
|
75 |
+
st.pyplot(sns.histplot(data=df, x='Severe Toxic', kde=True))
|
|
|
|
|
|
|
|