File size: 1,946 Bytes
fd77815
528da04
39e1615
fd77815
 
 
29406f8
fd77815
 
cb4608c
fd77815
94c9a58
fd77815
528da04
 
94c9a58
528da04
 
 
 
fd77815
390d16b
 
39e1615
 
 
4a0592e
b081db9
2d942ee
 
4a0592e
bf5dc75
7b977eb
fd77815
 
 
 
39e1615
 
2d942ee
7b977eb
cb4608c
7b977eb
94c9a58
7b977eb
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, pipeline
from transformers import (
    TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
)

# --- Page header and text input -------------------------------------------
st.title("Detecting Toxic Tweets")

# Sample text pre-filled in the input box.
demo = (
    "Your words are like poison. "
    "They seep into my mind and make me feel worthless."
)
text = st.text_area("Input Text", demo, height=250)

# --- Model selection --------------------------------------------------------
# Display name shown in the select box -> model identifier handed to
# `from_pretrained`.
model_options = {
    "DistilBERT Base Uncased (SST-2)": "distilbert-base-uncased-finetuned-sst-2-english",
    "Fine-tuned Toxicity Model": "RobCaamano/toxicity",
}
selected_model = st.selectbox("Select Model", options=list(model_options))
mod_name = model_options[selected_model]

# --- Tokenizer, model and classification pipeline ---------------------------
tokenizer = AutoTokenizer.from_pretrained(mod_name)
model = AutoModelForSequenceClassification.from_pretrained(mod_name)
clf = pipeline(
    "sentiment-analysis",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

# The fine-tuned toxicity model gets human-readable class names, assigned by
# label index.
# NOTE(review): this runs after the pipeline is built; presumably the pipeline
# reads the labels from the same `model.config` object at prediction time, so
# the override still takes effect -- confirm.
if selected_model == "Fine-tuned Toxicity Model":
    toxicity_classes = [
        "toxic",
        "severe_toxic",
        "obscene",
        "threat",
        "insult",
        "identity_hate",
    ]
    label_count = model.config.num_labels
    model.config.id2label = {
        idx: toxicity_classes[idx] for idx in range(label_count)
    }

def get_toxicity_class(predictions, threshold=0.3):
    """Return the labels whose prediction reaches *threshold*.

    Args:
        predictions: Iterable of scores; the position of each score is used
            as the key into ``model.config.id2label``.
        threshold: Minimum score (inclusive) for a label to be kept.

    Returns:
        A dict mapping label name to score for every kept prediction.
    """
    kept = {}
    for index, score in enumerate(predictions):
        if score >= threshold:
            kept[model.config.id2label[index]] = score
    return kept

# Classify the entered text when the "Submit" button is pressed and show the
# per-label scores as a table.
if st.button("Submit", type="primary"):
    # `clf(text)[0]` yields one result dict per label; read the label and the
    # score by key instead of relying on the order of the dict's values.
    results = {entry["label"]: entry["score"] for entry in clf(text)[0]}

    # Every label except "toxic" is listed in the table.
    toxic_labels = {
        label: score for label, score in results.items() if label != "toxic"
    }

    # Show at most the first 50 characters of the input, marking truncation.
    tweet_portion = (text[:50] + "...") if len(text) > 50 else text

    # NOTE(review): no score threshold is applied above, so this branch is
    # only reached when the model reports no label other than "toxic".
    # Confirm whether the threshold from `get_toxicity_class` was meant to be
    # applied here.
    if not toxic_labels:
        st.write("This text is not toxic.")
    else:
        df = pd.DataFrame(
            {
                "Text (portion)": [tweet_portion] * len(toxic_labels),
                "Toxicity Class": list(toxic_labels.keys()),
                "Probability": list(toxic_labels.values()),
            }
        )
        st.table(df)