import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, pipeline
from transformers import (
    TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
)
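
# Streamlit UI: page title and a text area pre-filled with a demo tweet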
st.title("Detecting Toxic Tweets")
demo = """Your words are like poison. They seep into my mind and make me feel worthless."""
text = st.text_area("Input Text", demo, height=250)
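
# Map display names to Hugging Face Hub model IDs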
model_options = {
"DistilBERT Base Uncased (SST-2)": "distilbert-base-uncased-finetuned-sst-2-english",
"Fine-tuned Toxicity Model": "RobCaamano/toxicity",
}
selected_model = st.selectbox("Select Model", options=list(model_options.keys()))
mod_name = model_options[selected_model]
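
# Load the tokenizer and TensorFlow model for the chosen checkpoint and wrap them
# in a classification pipeline that returns a score for every class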
tokenizer = AutoTokenizer.from_pretrained(mod_name)
model = AutoModelForSequenceClassification.from_pretrained(mod_name)
clf = pipeline(
"sentiment-analysis", model=model, tokenizer=tokenizer, return_all_scores=True
)
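
# Give the fine-tuned model's outputs human-readable toxicity class names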
if selected_model in ["Fine-tuned Toxicity Model"]:
    toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    model.config.id2label = {i: toxicity_classes[i] for i in range(model.config.num_labels)}
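
# Map prediction scores to label names, keeping only those at or above the threshold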
def get_toxicity_class(predictions, threshold=0.3):
    return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
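
# Classify the input and display the results when the user presses Submit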
if st.button("Submit", type="primary"):
    # return_all_scores=True yields one {"label", "score"} dict per class, in label-id
    # order, so the raw scores can be filtered through get_toxicity_class
    predictions = [pred["score"] for pred in clf(text)[0]]
    results = get_toxicity_class(predictions)

    # Drop the umbrella "toxic" label so the table lists only the specific categories
    toxic_labels = {k: v for k, v in results.items() if k != "toxic"}

    # Truncate long inputs for display
    tweet_portion = text[:50] + "..." if len(text) > 50 else text

    if len(toxic_labels) == 0:
        st.write("This text is not toxic.")
    else:
        df = pd.DataFrame(
            {
                "Text (portion)": [tweet_portion] * len(toxic_labels),
                "Toxicity Class": list(toxic_labels.keys()),
                "Probability": list(toxic_labels.values()),
            }
        )
        st.table(df)