---
datasets:
- seanius/toxic-or-neutral-text-labelled
language:
- en
library_name: transformers
base_model: distilbert/distilbert-base-uncased
license: mit
---

An ONNX model: a fine-tuned version of DistilBERT that classifies text as one of:

- neutral
- offensive_language
- harmful_behaviour
- hate_speech

The model was trained using the [csfy tool](https://github.com/mrseanryan/csfy) and the dataset [seanius/toxic-or-neutral-text-labelled](https://huggingface.co/datasets/seanius/toxic-or-neutral-text-labelled).

The base model (distilbert-base-uncased) is also required; the example below loads its tokenizer.

For an example of how to run the model, see the Usage section below or the [csfy tool](https://github.com/mrseanryan/csfy).

The model outputs a number indicating the class, which is decoded to a label name via the `label_mapping.json` file.

# Usage

```python
import json

import numpy as np
import onnxruntime as ort
from transformers import DistilBertTokenizer

# Loading the label mappings (class index -> label name)
def load_label_mappings():
    with open("./label_mapping.json", encoding="utf-8") as f:
        data = json.load(f)
    return data['labels']

label_mappings = load_label_mappings()

# Loading the base model's tokenizer and the ONNX model
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
ort_session = ort.InferenceSession("./toxic-or-neutral-text-labelled.onnx")

# Predicting the label for the given text
def predict_via_onnx(text, ort_session, tokenizer, label_mappings):
    model_input = ort_session.get_inputs()[0]
    model_expected_input_shape = model_input.shape
    print("Model expects input shape:", model_expected_input_shape)

    # Pad/truncate to the sequence length the model expects
    inputs = tokenizer(text, return_tensors="np", padding="max_length",
                       truncation=True, max_length=model_expected_input_shape[1])
    print("input shape", inputs['input_ids'].shape)

    input_ids = inputs['input_ids']
    if input_ids.ndim == 1:
        input_ids = input_ids[np.newaxis, :]  # add a batch dimension

    # Feed the token IDs as int64 under the model's actual input name
    ort_inputs = {model_input.name: input_ids.astype(np.int64)}

    # run() returns a list of outputs; the first holds the class scores
    ort_outputs = ort_session.run(None, ort_inputs)
    predictions = np.argmax(ort_outputs[0], axis=-1)
    predicted_label = label_mappings[predictions.item()]
    return predicted_label

predicted_label = predict_via_onnx("How do I get to the beach?", ort_session, tokenizer, label_mappings)
print(predicted_label)
```
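The last step of `predict_via_onnx` turns the model's raw scores into a label. The sketch below isolates that decoding step so it can be tried without the ONNX file, and adds a softmax to report a confidence score, which the example above does not compute. It assumes the first ONNX output is a `(batch, num_classes)` array of logits and that `label_mapping.json` stores `labels` as a list indexed by class. The `decode_logits` helper, the label order, and the scores in the example are hypothetical; read the real label order from `label_mapping.json`.

```python
import numpy as np

def decode_logits(logits, labels):
    """Hypothetical helper: map each row of logits to (label, confidence).

    Assumes `logits` has shape (batch, num_classes) and that `labels[i]`
    is the name of class i (as read from label_mapping.json).
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract each row's max before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # The argmax of the probabilities equals the argmax of the raw logits
    indices = np.argmax(probs, axis=-1)
    return [(labels[int(i)], float(probs[row, i])) for row, i in enumerate(indices)]

# Hypothetical label order; the real order comes from label_mapping.json
labels = ["neutral", "offensive_language", "harmful_behaviour", "hate_speech"]
example_logits = np.array([[3.0, 0.5, -1.0, -2.0]])  # made-up scores for one text
print(decode_logits(example_logits, labels))
```

Because softmax preserves the ordering of the scores, the predicted label is the same one `np.argmax` picks in `predict_via_onnx`; the softmax only adds a probability-like confidence alongside it.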