|
import onnxruntime as ort |
|
from transformers import AutoTokenizer |
|
import gradio as gr |
|
|
|
|
|
# Registry of selectable checkpoints: each entry pairs an exported ONNX
# graph on disk with the Hugging Face tokenizer it was trained alongside.
models = {
    name: {"onnx_model_path": onnx_path, "tokenizer_name": tokenizer_id}
    for name, onnx_path, tokenizer_id in [
        ("DistilBERT", "distilbert.onnx", "distilbert-base-multilingual-cased"),
        ("BERT", "bert.onnx", "bert-base-multilingual-cased"),
        ("MuRIL", "muril.onnx", "google/muril-base-cased"),
        ("RoBERTa", "roberta.onnx", "cardiffnlp/twitter-roberta-base-emotion"),
    ]
}
|
|
|
|
|
# Eagerly build one ONNX Runtime session and one tokenizer per registered
# model, so the first user request doesn't pay a cold-start penalty.
model_sessions = {}
tokenizers = {}

for name in models:
    cfg = models[name]
    print(f"Loading {name}...")
    session = ort.InferenceSession(cfg["onnx_model_path"])
    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_name"])
    model_sessions[name] = session
    tokenizers[name] = tokenizer

print("All models loaded!")
|
|
|
|
|
def predict_with_model(text, model_name):
    """Classify ``text`` with the selected ONNX model.

    Args:
        text: Input string to classify.
        model_name: Key into the module-level ``models`` registry
            (must match an entry in ``model_sessions``/``tokenizers``).

    Returns:
        ``"Hate Speech"`` if the predicted class index is 1, otherwise
        ``"Not Hate Speech"``.
    """
    ort_session = model_sessions[model_name]
    tokenizer = tokenizers[model_name]

    # Tokenize straight to numpy so the arrays can be fed to ONNX Runtime.
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)

    # Feed only the inputs this session actually declares. BERT-style exports
    # (BERT, MuRIL) typically also expect "token_type_ids", while
    # DistilBERT/RoBERTa exports do not — the previous hard-coded dict broke
    # for any session whose input signature differed.
    expected = {node.name for node in ort_session.get_inputs()}
    ort_inputs = {name: arr for name, arr in inputs.items() if name in expected}

    outputs = ort_session.run(None, ort_inputs)

    logits = outputs[0]
    # Use argmax over ALL classes. Comparing only logits[0][1] > logits[0][0]
    # ignored any extra classes (the RoBERTa emotion head has 4 labels);
    # for binary heads the result is identical, ties included.
    predicted_class = int(logits[0].argmax())
    # NOTE(review): class index 1 == hate is assumed for every registered
    # model — verify each checkpoint's label mapping.
    return "Hate Speech" if predicted_class == 1 else "Not Hate Speech"
|
|
|
|
|
# Assemble the Gradio UI: free-text input plus a dropdown that selects
# which loaded model runs the classification.
_text_box = gr.Textbox(label="Enter text to classify")
_model_picker = gr.Dropdown(
    choices=list(models.keys()),
    label="Select a model",
)

interface = gr.Interface(
    fn=predict_with_model,
    inputs=[_text_box, _model_picker],
    outputs="text",
    title="Multi-Model Hate Speech Detection",
    description="Choose a model and enter text to classify whether it's hate speech.",
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio web server only when run as a script, not on import.
    interface.launch()
|
|