Jainesh212 committed on
Commit 86c0799
1 Parent(s): df1666f

Update app.py

Files changed (1)
  1. app.py +46 -43
app.py CHANGED
@@ -6,8 +6,11 @@ import seaborn as sns
 import matplotlib.pyplot as plt
 
 # Load the pre-trained BERT model and tokenizer
-tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
-model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
+try:
+    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
+    model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
+except Exception as e:
+    st.error(f"Error loading the model: {e}")
 
 # Set up the Streamlit app
 st.set_page_config(layout="wide")
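
Note on this hunk: the try/except is a good addition, but the load still runs on every Streamlit rerun, and a stock 'bert-base-uncased' checkpoint has a randomly initialized 6-label classification head, so the scores are not meaningful until the model is fine-tuned on toxicity data. A possible refinement (a sketch, not part of this commit; the load_model helper name is assumed) is to cache the load and stop the script when it fails:

    import streamlit as st
    import transformers

    @st.cache_resource
    def load_model():
        # Runs once per process; Streamlit caches and reuses the return value
        tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
        model = transformers.BertForSequenceClassification.from_pretrained(
            'bert-base-uncased', num_labels=6)
        model.eval()  # inference only
        return tokenizer, model

    try:
        tokenizer, model = load_model()
    except Exception as e:
        st.error(f"Error loading the model: {e}")
        st.stop()  # everything below needs tokenizer/model, so end the rerun here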
@@ -18,45 +21,49 @@ text_input = st.text_input('Enter text to classify')
 
 # Create a button to run the classification
 if st.button('Classify'):
-    # Tokenize the text and convert to input IDs
-    encoded_text = tokenizer.encode_plus(
-        text_input,
-        max_length=512,
-        padding='max_length',
-        truncation=True,
-        add_special_tokens=True,
-        return_attention_mask=True,
-        return_tensors='pt'
-    )
+    if not text_input:
+        st.warning("Please enter text to classify.")
+    else:
+        # Tokenize the text and convert to input IDs
+        encoded_text = tokenizer.encode_plus(
+            text_input,
+            max_length=512,
+            padding='max_length',
+            truncation=True,
+            add_special_tokens=True,
+            return_attention_mask=True,
+            return_tensors='pt'
+        )
 
-    # Run the text through the model
-    with torch.no_grad():
-        output = model(encoded_text['input_ids'], encoded_text['attention_mask'])
-        probabilities = torch.nn.functional.softmax(output[0], dim=1).tolist()[0]
+        # Run the text through the model
+        with torch.no_grad():
+            output = model(encoded_text['input_ids'], encoded_text['attention_mask'])
+            probabilities = torch.nn.functional.softmax(output[0], dim=1).tolist()[0]
 
-    # Display the classification results
-    st.write('Toxic:', probabilities[0])
-    st.write('Severe Toxic:', probabilities[1])
-    st.write('Obscene:', probabilities[2])
-    st.write('Threat:', probabilities[3])
-    st.write('Insult:', probabilities[4])
-    st.write('Identity Hate:', probabilities[5])
+        # Display the classification results
+        st.write('Toxic:', probabilities[0])
+        st.write('Severe Toxic:', probabilities[1])
+        st.write('Obscene:', probabilities[2])
+        st.write('Threat:', probabilities[3])
+        st.write('Insult:', probabilities[4])
+        st.write('Identity Hate:', probabilities[5])
 
-    # Create a DataFrame to store the classification results
-    results_df = pd.DataFrame({
-        'Text': [text_input],
-        'Toxic': [probabilities[0]],
-        'Severe Toxic': [probabilities[1]],
-        'Obscene': [probabilities[2]],
-        'Threat': [probabilities[3]],
-        'Insult': [probabilities[4]],
-        'Identity Hate': [probabilities[5]]
-    })
+        # Create a DataFrame to store the classification results
+        results_df = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
+        results_df = results_df.append({
+            'Text': text_input,
+            'Toxic': probabilities[0],
+            'Severe Toxic': probabilities[1],
+            'Obscene': probabilities[2],
+            'Threat': probabilities[3],
+            'Insult': probabilities[4],
+            'Identity Hate': probabilities[5]
+        }, ignore_index=True)
 
-    # Append the classification results to the persistent DataFrame
-    if 'results' not in st.session_state:
-        st.session_state['results'] = pd.DataFrame(columns=results_df.columns)
-    st.session_state['results'] = st.session_state['results'].append(results_df, ignore_index=True)
+        # Append the classification results to the persistent DataFrame
+        if 'results' not in st.session_state:
+            st.session_state['results'] = pd.DataFrame(columns=results_df.columns)
+        st.session_state['results'] = st.session_state['results'].append(results_df, ignore_index=True)
 
 # Display the persistent DataFrame
 st.write('Classification Results:', st.session_state.get('results', pd.DataFrame()))
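
Note on this hunk: the empty-input guard is a clear improvement. Two further points, offered as a hedged sketch rather than a required change: recent transformers versions deprecate encode_plus in favor of calling the tokenizer directly, and if the six labels follow the usual multi-label toxicity scheme (Jigsaw-style), independent sigmoid scores arguably fit better than a softmax that forces the six probabilities to sum to 1:

    import torch

    # Assumes tokenizer, model, and text_input defined in the app above.
    encoded = tokenizer(
        text_input,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = model(**encoded).logits  # shape (1, 6)
    probabilities = torch.sigmoid(logits)[0].tolist()  # one independent score per label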
@@ -64,9 +71,5 @@ st.write('Classification Results:', st.session_state.get('results', pd.DataFrame()))
 # Plot the distribution of probabilities for each category
 if len(st.session_state.get('results', pd.DataFrame())) > 0:
     df = st.session_state['results']
-    fig, axes = plt.subplots(ncols=2, figsize=(12, 6))
-    sns.histplot(data=df, x='Toxic', kde=True, ax=axes[0])
-    axes[0].set_title('Toxic Probability Distribution')
-    sns.histplot(data=df, x='Severe Toxic', kde=True, ax=axes[1])
-    axes[1].set_title('Severe Toxic Probability Distribution')
-    st.pyplot(fig)
+    st.pyplot(sns.histplot(data=df, x='Toxic', kde=True))
+    st.pyplot(sns.histplot(data=df, x='Severe Toxic', kde=True))
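
One caveat that applies to both versions of the session-state code above: DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, so the .append(...) calls will raise AttributeError on current pandas. An equivalent sketch with pd.concat, keeping the app's column names:

    import pandas as pd
    import streamlit as st

    # Assumes text_input and probabilities defined in the app above.
    row = pd.DataFrame([{
        'Text': text_input,
        'Toxic': probabilities[0],
        'Severe Toxic': probabilities[1],
        'Obscene': probabilities[2],
        'Threat': probabilities[3],
        'Insult': probabilities[4],
        'Identity Hate': probabilities[5],
    }])
    if 'results' not in st.session_state:
        st.session_state['results'] = row
    else:
        st.session_state['results'] = pd.concat(
            [st.session_state['results'], row], ignore_index=True)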
 
6
  import matplotlib.pyplot as plt
7
 
8
  # Load the pre-trained BERT model and tokenizer
9
+ try:
10
+ tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
11
+ model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
12
+ except Exception as e:
13
+ st.error(f"Error loading the model: {e}")
14
 
15
  # Set up the Streamlit app
16
  st.set_page_config(layout="wide")
 
21
 
22
  # Create a button to run the classification
23
  if st.button('Classify'):
24
+ if not text_input:
25
+ st.warning("Please enter text to classify.")
26
+ else:
27
+ # Tokenize the text and convert to input IDs
28
+ encoded_text = tokenizer.encode_plus(
29
+ text_input,
30
+ max_length=512,
31
+ padding='max_length',
32
+ truncation=True,
33
+ add_special_tokens=True,
34
+ return_attention_mask=True,
35
+ return_tensors='pt'
36
+ )
37
 
38
+ # Run the text through the model
39
+ with torch.no_grad():
40
+ output = model(encoded_text['input_ids'], encoded_text['attention_mask'])
41
+ probabilities = torch.nn.functional.softmax(output[0], dim=1).tolist()[0]
42
 
43
+ # Display the classification results
44
+ st.write('Toxic:', probabilities[0])
45
+ st.write('Severe Toxic:', probabilities[1])
46
+ st.write('Obscene:', probabilities[2])
47
+ st.write('Threat:', probabilities[3])
48
+ st.write('Insult:', probabilities[4])
49
+ st.write('Identity Hate:', probabilities[5])
50
 
51
+ # Create a DataFrame to store the classification results
52
+ results_df = pd.DataFrame(columns=['Text', 'Toxic', 'Severe Toxic', 'Obscene', 'Threat', 'Insult', 'Identity Hate'])
53
+ results_df = results_df.append({
54
+ 'Text': text_input,
55
+ 'Toxic': probabilities[0],
56
+ 'Severe Toxic': probabilities[1],
57
+ 'Obscene': probabilities[2],
58
+ 'Threat': probabilities[3],
59
+ 'Insult': probabilities[4],
60
+ 'Identity Hate': probabilities[5]
61
+ }, ignore_index=True)
62
 
63
+ # Append the classification results to the persistent DataFrame
64
+ if 'results' not in st.session_state:
65
+ st.session_state['results'] = pd.DataFrame(columns=results_df.columns)
66
+ st.session_state['results'] = st.session_state['results'].append(results_df, ignore_index=True)
67
 
68
  # Display the persistent DataFrame
69
  st.write('Classification Results:', st.session_state.get('results', pd.DataFrame()))
 
71
  # Plot the distribution of probabilities for each category
72
  if len(st.session_state.get('results', pd.DataFrame())) > 0:
73
  df = st.session_state['results']
74
+ st.pyplot(sns.histplot(data=df, x='Toxic', kde=True))
75
+ st.pyplot(sns.histplot(data=df, x='Severe Toxic', kde=True))
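And on the plotting hunk: st.pyplot expects a Matplotlib figure, while sns.histplot returns an Axes, and the two calls draw onto the same implicit figure, so the second chart would also contain the first. The two-panel layout this commit removes avoids both issues; restoring it (essentially the pre-commit code, shown here as a self-contained sketch):

    import matplotlib.pyplot as plt
    import seaborn as sns
    import streamlit as st

    # Assumes df = st.session_state['results'] as in the app above.
    fig, axes = plt.subplots(ncols=2, figsize=(12, 6))
    sns.histplot(data=df, x='Toxic', kde=True, ax=axes[0])
    axes[0].set_title('Toxic Probability Distribution')
    sns.histplot(data=df, x='Severe Toxic', kde=True, ax=axes[1])
    axes[1].set_title('Severe Toxic Probability Distribution')
    st.pyplot(fig)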