Overview

This is a multi-label, multi-class linear classifier for emotions that works with sentence-transformers/all-MiniLM-L12-v2 embeddings. It was trained on the go_emotions dataset.

Labels

The 28 labels from the go_emotions dataset are:

['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

Metrics (exact match of labels per item)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification. The metrics below are evaluated across all labels per item on the go_emotions test split.
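Because each label is scored as its own binary problem, precision, recall and F1 can be computed column by column over a 0/1 prediction matrix and then averaged across labels. The sketch below assumes NumPy and binary arrays of shape (n_items, n_labels); the function name per_label_prf and the toy data are hypothetical, and this is not necessarily how the figures in this card were produced.

```python
import numpy as np


def per_label_prf(y_true: np.ndarray, y_pred: np.ndarray):
    """Treat each label column as an independent binary classification.

    y_true, y_pred: 0/1 arrays of shape (n_items, n_labels).
    Returns per-label precision, recall and F1 as 1-D arrays of length n_labels.
    """
    y_true = y_true.astype(bool)
    y_pred = y_pred.astype(bool)
    tp = np.sum(y_true & y_pred, axis=0)
    fp = np.sum(~y_true & y_pred, axis=0)
    fn = np.sum(y_true & ~y_pred, axis=0)
    # Define 0/0 as 0 so labels with no predictions or no positives do not yield NaN
    precision = np.divide(tp, tp + fp, out=np.zeros(tp.shape, dtype=float), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros(tp.shape, dtype=float), where=(tp + fn) > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros(tp.shape, dtype=float), where=denom > 0)
    return precision, recall, f1


# Toy example (not go_emotions data): 4 items, 3 labels
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
y_pred = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0], [1, 0, 1]])
p, r, f = per_label_prf(y_true, y_pred)
print("per-label F1:", f)
print("unweighted mean F1:", f.mean())
```

Averaging the per-label arrays with np.mean gives an unweighted figure; weighting by each label's support gives the weighted variant.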

With the threshold for each label chosen to optimise F1, the metrics (evaluated on the go_emotions test split) are:

  • Precision: 0.378
  • Recall: 0.438
  • F1: 0.394

Weighted by the relative support of each label in the go_emotions test split, these become:

  • Precision: 0.424
  • Recall: 0.590
  • F1: 0.481

Using a fixed threshold of 0.5 to convert the scores to binary predictions for each label, the metrics (evaluated on the go_emotions test split, and unweighted by support) are:

  • Precision: 0.568
  • Recall: 0.214
  • F1: 0.260

Metrics (per-label)

This is a multi-label, multi-class dataset, so each label is effectively a separate binary classification and metrics are better measured per label.

With the threshold for each label chosen to optimise F1, the metrics (evaluated on the go_emotions test split) are:

label           f1     precision  recall  support  threshold
admiration      0.540  0.463      0.649   504      0.20
amusement       0.686  0.669      0.705   264      0.25
anger           0.419  0.373      0.480   198      0.15
annoyance       0.276  0.189      0.512   320      0.10
approval        0.299  0.260      0.350   351      0.15
caring          0.303  0.219      0.489   135      0.10
confusion       0.284  0.269      0.301   153      0.15
curiosity       0.365  0.310      0.444   284      0.15
desire          0.274  0.237      0.325   83       0.15
disappointment  0.188  0.292      0.139   151      0.20
disapproval     0.305  0.257      0.375   267      0.15
disgust         0.450  0.462      0.439   123      0.20
embarrassment   0.348  0.375      0.324   37       0.30
excitement      0.313  0.306      0.320   103      0.20
fear            0.550  0.505      0.603   78       0.25
gratitude       0.776  0.774      0.778   352      0.30
grief           0.353  0.273      0.500   6        0.70
joy             0.370  0.361      0.379   161      0.20
love            0.626  0.717      0.555   238      0.35
nervousness     0.308  0.276      0.348   23       0.55
optimism        0.436  0.432      0.441   186      0.20
pride           0.444  0.545      0.375   16       0.60
realization     0.171  0.146      0.207   145      0.10
relief          0.133  0.250      0.091   11       0.60
remorse         0.468  0.426      0.518   56       0.30
sadness         0.413  0.409      0.417   156      0.20
surprise        0.314  0.303      0.326   141      0.15
neutral         0.622  0.482      0.879   1787     0.25
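As a consistency check, the unweighted and support-weighted F1 values quoted in the earlier metrics section can be recovered from this table: the former is the plain mean of the per-label F1 column, the latter its mean weighted by support. A plain-Python sketch with the values copied from the table (variable names are illustrative):

```python
# (label, F1, support) rows copied from the per-label table
rows = [
    ("admiration", 0.540, 504), ("amusement", 0.686, 264), ("anger", 0.419, 198),
    ("annoyance", 0.276, 320), ("approval", 0.299, 351), ("caring", 0.303, 135),
    ("confusion", 0.284, 153), ("curiosity", 0.365, 284), ("desire", 0.274, 83),
    ("disappointment", 0.188, 151), ("disapproval", 0.305, 267), ("disgust", 0.450, 123),
    ("embarrassment", 0.348, 37), ("excitement", 0.313, 103), ("fear", 0.550, 78),
    ("gratitude", 0.776, 352), ("grief", 0.353, 6), ("joy", 0.370, 161),
    ("love", 0.626, 238), ("nervousness", 0.308, 23), ("optimism", 0.436, 186),
    ("pride", 0.444, 16), ("realization", 0.171, 145), ("relief", 0.133, 11),
    ("remorse", 0.468, 56), ("sadness", 0.413, 156), ("surprise", 0.314, 141),
    ("neutral", 0.622, 1787),
]

f1 = [row[1] for row in rows]
support = [row[2] for row in rows]

# Unweighted: every label counts equally, however rare it is
macro_f1 = sum(f1) / len(f1)
# Weighted: each label's F1 counts in proportion to its test-split support
weighted_f1 = sum(f * s for f, s in zip(f1, support)) / sum(support)

print(f"unweighted F1: {macro_f1:.3f}")        # 0.394
print(f"support-weighted F1: {weighted_f1:.3f}")  # 0.481
```

The gap between the two reflects neutral: it has by far the largest support and one of the higher F1 scores, so it pulls the weighted figure up.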

The per-label thresholds are stored in thresholds.json.
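The card does not say how the per-label thresholds were searched. One plausible approach, shown only as a sketch, is a grid search over candidate thresholds in steps of 0.05 (every tabulated threshold is a multiple of 0.05), keeping the value that maximises F1 for that label. The function name best_threshold, the grid and the toy data are assumptions, not the author's method.

```python
import numpy as np


def best_threshold(y_true, scores, grid=None):
    """Return (threshold, f1) maximising F1 for a single label.

    y_true: 0/1 array of shape (n_items,); scores: positive-case probabilities.
    """
    if grid is None:
        # Candidate thresholds 0.05, 0.10, ..., 0.95 (assumed grid)
        grid = np.round(np.arange(1, 20) * 0.05, 2)
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores)
    best_t, best_f1 = 0.5, -1.0
    for t in grid:
        pred = scores >= t
        tp = np.sum(pred & y_true)
        fp = np.sum(pred & ~y_true)
        fn = np.sum(~pred & y_true)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0  # F1 = 2TP / (2TP + FP + FN)
        if f1 > best_f1:  # strict '>' keeps the lowest threshold on ties
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1


# Toy data for one label (not from go_emotions)
y_true = [1, 1, 0, 0, 1, 0]
scores = [0.9, 0.3, 0.2, 0.6, 0.4, 0.1]
print(best_threshold(y_true, scores))
```

Running this once per label column yields a threshold per label, which is the shape of information thresholds.json holds.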

Use with ONNXRuntime

The model's input is named logits (despite the name, it takes the all-MiniLM-L12-v2 sentence embeddings), and there is one output per label. Each output is a 2D array with one row per input row and two columns: the first is the probability of the negative case and the second the probability of the positive case.
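Since each label has its own output, it is often convenient to collect the positive-case column from every output into a single matrix with one column per label, in the same order as the session's output names. A minimal NumPy sketch; the helper name positive_probs and the toy arrays are hypothetical stand-ins for the list returned by the ONNX Runtime session.

```python
import numpy as np


def positive_probs(preds_onnx):
    """Stack per-label outputs into one (n_rows, n_labels) array.

    preds_onnx: list with one (n_rows, 2) array per label, where column 0 is
    the negative-case probability and column 1 the positive-case probability.
    """
    return np.stack([p[:, 1] for p in preds_onnx], axis=1)


# Made-up stand-in for session output: 3 labels, batch of 2 rows
fake_preds = [
    np.array([[0.9, 0.1], [0.2, 0.8]], dtype=np.float32),
    np.array([[0.6, 0.4], [0.7, 0.3]], dtype=np.float32),
    np.array([[0.05, 0.95], [0.5, 0.5]], dtype=np.float32),
]
probs = positive_probs(fake_preds)
print(probs.shape)  # (2, 3)
```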

import onnxruntime as ort

# Assuming you have embeddings from all-MiniLM-L12-v2 for the input sentences,
# e.g. produced with sentence-transformers:
#     huggingface.co/sentence-transformers/all-MiniLM-L12-v2
# or with an ONNX version, e.g.:
#     huggingface.co/Xenova/all-MiniLM-L12-v2

print(embeddings.shape)  # e.g. a batch of 1 sentence
# (1, 384)

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=["CPUExecutionProvider"])

# Output names are the labels, in the same order as the outputs
outputs = [o.name for o in sess.get_outputs()]
preds_onnx = sess.run(outputs, {"logits": embeddings})
# preds_onnx is a list with 28 entries, one per label,
# each a numpy array of shape (1, 2) given the input was a batch of 1

print(outputs[0])
# surprise
print(repr(preds_onnx[0]))
# array([[0.97136074, 0.02863926]], dtype=float32)

# Load thresholds.json and use the per-label threshold to convert the
# positive-case score (column 1 of each output) to a binary prediction
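The final thresholding step above is left as a comment. Here is a sketch of it, assuming thresholds.json holds a JSON object mapping each label name to its threshold (this layout is an assumption; check the actual file). The name apply_thresholds and the probabilities are hypothetical; the three threshold values are taken from the per-label table. In real use the dictionary would come from json.load on the file.

```python
import json

import numpy as np


def apply_thresholds(probs, labels, thresholds):
    """Binarise positive-case probabilities with one threshold per label.

    probs: (n_rows, n_labels) array; labels: label names in column order;
    thresholds: dict mapping label name -> threshold (assumed file layout).
    """
    t = np.array([thresholds[label] for label in labels], dtype=probs.dtype)
    return probs >= t  # the threshold vector broadcasts across rows


# Hypothetical stand-in for json.load(open("thresholds.json"))
thresholds = json.loads('{"surprise": 0.15, "neutral": 0.25, "relief": 0.60}')
labels = ["surprise", "neutral", "relief"]
probs = np.array([[0.20, 0.10, 0.55], [0.05, 0.90, 0.65]], dtype=np.float32)

per_label = apply_thresholds(probs, labels, thresholds)
fixed = probs >= 0.5  # the fixed 0.5 alternative, for comparison
print(per_label)
print(fixed)
```

The first row shows the difference in practice: surprise at 0.20 passes its 0.15 threshold but not 0.5, while relief at 0.55 passes 0.5 but not its 0.60 threshold.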

Commentary on the dataset

Considered independently, some labels (e.g. gratitude, F1 0.776) perform very strongly, whilst others (e.g. relief, F1 0.133) perform very poorly.

This is a challenging dataset. Labels such as relief have far fewer examples in the training data (fewer than 100 out of the 40k+, and only 11 in the test split).

There is also some ambiguity, and some labelling errors, visible in the go_emotions training data, which are suspected to constrain performance. Cleaning the dataset to reduce mistakes, ambiguity, conflicts and duplication in the labelling should produce a higher-performing model.

Dataset used to train SamLowe/all-MiniLM-L12-v2-go_emotions-classifier-onnx: go_emotions