import onnxruntime as ort
from transformers import AutoTokenizer
import gradio as gr

# Define available models with their ONNX file paths and tokenizer names
models = {
    "DistilBERT": {
        "onnx_model_path": "distilbert.onnx",
        "tokenizer_name": "distilbert-base-multilingual-cased",
    },
    "BERT": {
        "onnx_model_path": "bert.onnx",
        "tokenizer_name": "bert-base-multilingual-cased",
    },
    "MuRIL": {
        "onnx_model_path": "muril.onnx",
        "tokenizer_name": "google/muril-base-cased",
    },
    "RoBERTa": {
        "onnx_model_path": "roberta.onnx",
        "tokenizer_name": "cardiffnlp/twitter-roberta-base-emotion",
    },
}

# Load every ONNX session and tokenizer into memory once at startup
model_sessions = {}
tokenizers = {}
for model_name, config in models.items():
    print(f"Loading {model_name}...")
    model_sessions[model_name] = ort.InferenceSession(config["onnx_model_path"])
    tokenizers[model_name] = AutoTokenizer.from_pretrained(config["tokenizer_name"])
print("All models loaded!")


# Prediction function
def predict_with_model(text, model_name):
    # Select the appropriate ONNX session and tokenizer
    ort_session = model_sessions[model_name]
    tokenizer = tokenizers[model_name]

    # Tokenize the input text (NumPy tensors, since ONNX Runtime
    # does not accept PyTorch tensors)
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True)

    # Feed only the inputs the exported graph actually declares:
    # BERT-style exports often expect token_type_ids, while
    # DistilBERT/RoBERTa exports usually do not
    expected = {node.name for node in ort_session.get_inputs()}
    ort_inputs = {name: value for name, value in inputs.items() if name in expected}

    # Run ONNX inference
    outputs = ort_session.run(None, ort_inputs)

    # Post-process: index 1 is the "hate speech" logit, index 0 "not hate speech"
    logits = outputs[0]
    label = "Hate Speech" if logits[0][1] > logits[0][0] else "Not Hate Speech"
    return label


# Define the Gradio interface
interface = gr.Interface(
    fn=predict_with_model,
    inputs=[
        gr.Textbox(label="Enter text to classify"),
        gr.Dropdown(
            choices=list(models.keys()),
            label="Select a model",
        ),
    ],
    outputs="text",
    title="Multi-Model Hate Speech Detection",
    description="Choose a model and enter text to classify whether it's hate speech.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()
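
# ---------------------------------------------------------------------------
# The app above assumes the *.onnx files already exist on disk. As a hedged
# sketch (not part of the original app), one way to produce such a file with
# torch.onnx.export is shown below; "my-finetuned-distilbert" is a
# hypothetical fine-tuned checkpoint, and the axis names are illustrative.
# ---------------------------------------------------------------------------
# import torch
# from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
# ckpt = "my-finetuned-distilbert"  # hypothetical checkpoint directory
# model = AutoModelForSequenceClassification.from_pretrained(ckpt)
# tok = AutoTokenizer.from_pretrained(ckpt)
# model.eval()
#
# # Trace the model with a sample input and mark batch/sequence as dynamic,
# # so the exported graph accepts variable-length inputs at inference time
# sample = tok("example input", return_tensors="pt")
# torch.onnx.export(
#     model,
#     (sample["input_ids"], sample["attention_mask"]),
#     "distilbert.onnx",
#     input_names=["input_ids", "attention_mask"],
#     output_names=["logits"],
#     dynamic_axes={
#         "input_ids": {0: "batch", 1: "sequence"},
#         "attention_mask": {0: "batch", 1: "sequence"},
#         "logits": {0: "batch"},
#     },
#     opset_version=14,
# )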