metadata
language: en
license: apache-2.0
datasets:
  - sst2

DistilBERT optimized for Apple Neural Engine

This is the distilbert-base-uncased-finetuned-sst-2-english model, optimized for the Apple Neural Engine (ANE) as described in the article Deploying Transformers on the Apple Neural Engine.

The source code is taken from Apple's ml-ane-transformers GitHub repo, modified slightly to make it usable from the 🤗 Transformers library.

For more details about DistilBERT, we encourage users to check out the model card of the original distilbert-base-uncased-finetuned-sst-2-english model.

How to use

Usage example:

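import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "apple/ane-distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# trust_remote_code=True loads the ANE-optimized model code from the Hub repository
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, trust_remote_code=True, return_dict=False,
)

# Pad every input to a fixed length of 128 tokens
inputs = tokenizer(
    ["The Neural Engine is really fast"],
    return_tensors="pt",
    max_length=128,
    padding="max_length",
)

# Inference only, so gradient tracking is disabled.
# With return_dict=False, outputs is a tuple rather than a ModelOutput object.
with torch.no_grad():
    outputs = model(**inputs)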
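Because of `return_dict=False`, the logits are expected in `outputs[0]`. The sketch below turns one row of logits into a label and its softmax probability. It assumes the label mapping of the original SST-2 checkpoint (0 = NEGATIVE, 1 = POSITIVE); in practice, pass `model.config.id2label` instead of relying on that default. The helper name `decode_sentiment` is hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

# Assumed mapping, taken from the original SST-2 checkpoint; prefer model.config.id2label.
DEFAULT_ID2LABEL: Dict[int, str] = {0: "NEGATIVE", 1: "POSITIVE"}


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_sentiment(
    logits: Sequence[float], id2label: Dict[int, str] = DEFAULT_ID2LABEL
) -> Tuple[str, float]:
    """Hypothetical helper: return (label, probability) for one row of logits."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]


# Usage with the PyTorch outputs above (assumes logits of shape (batch, num_labels)):
#   label, prob = decode_sentiment(outputs[0][0].tolist(), model.config.id2label)
print(decode_sentiment([-1.0, 2.0]))
```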

Using the model with Core ML

PyTorch does not use the ANE, so running this version of the model with PyTorch on the CPU or GPU may actually be slower than running the original model. To take advantage of the ANE's hardware acceleration, use the Core ML version of the model, DistilBERT_fp16.mlpackage.

Core ML usage example from Python:

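import coremltools as ct
import numpy as np
from transformers import AutoTokenizer

# Core ML predictions can only be run on macOS.
# The tokenizer is the same one used in the PyTorch example.
tokenizer = AutoTokenizer.from_pretrained("apple/ane-distilbert-base-uncased-finetuned-sst-2-english")
mlmodel = ct.models.MLModel("DistilBERT_fp16.mlpackage")

inputs = tokenizer(
    ["The Neural Engine is really fast"],
    return_tensors="np",
    max_length=128,
    padding="max_length",
)

# The tokenizer returns int64 arrays; cast them to int32 for the Core ML model
outputs_coreml = mlmodel.predict({
    "input_ids": inputs["input_ids"].astype(np.int32),
    "attention_mask": inputs["attention_mask"].astype(np.int32),
})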
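`mlmodel.predict` returns a dictionary keyed by the model's output names, and this card does not state what the output is called. One option is to list the names with `[o.name for o in mlmodel.get_spec().description.output]`. Another is a small guard like the hypothetical `single_output` helper below, which returns the only entry and fails loudly if there is more than one.

```python
from typing import Any, Mapping, Optional


def single_output(prediction: Mapping[str, Any], name: Optional[str] = None) -> Any:
    """Hypothetical helper: pull one output out of an MLModel.predict() result.

    If `name` is given, that key is returned. Otherwise the dict must contain
    exactly one entry, and its value is returned.
    """
    if name is not None:
        return prediction[name]
    if len(prediction) != 1:
        raise ValueError(
            f"expected exactly one output, got {sorted(prediction)}; pass `name` explicitly"
        )
    (value,) = prediction.values()
    return value


# Usage with the Core ML example above (the output name is deliberately not assumed):
#   logits = single_output(outputs_coreml)
print(single_output({"logits": [[-1.0, 2.0]]}))
```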

To use the model from Swift, you will need to tokenize the input yourself, following BERT's WordPiece tokenization rules for an uncased model. You can find a Swift implementation of the BERT tokenizer here.
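As a reference for what a Swift port has to reproduce, the Python sketch below walks through the main steps for an uncased BERT model:

- lowercasing
- splitting on whitespace and punctuation
- greedy longest-match-first WordPiece with `##` continuation pieces
- adding `[CLS]`/`[SEP]`
- padding to a fixed length with an attention mask

The tiny vocabulary and its ids are made up for illustration. A real implementation must load the model's `vocab.txt`. The full BERT basic tokenizer also strips accents and handles CJK and control characters, which this simplified sketch omits.

```python
import unicodedata
from typing import Dict, List, Tuple

# Toy vocabulary for illustration only; real ids come from the model's vocab.txt.
TOY_VOCAB: Dict[str, int] = {
    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
    "the": 4, "neural": 5, "engine": 6, "is": 7, "really": 8,
    "fast": 9, "fa": 10, "##st": 11, "##er": 12, "!": 13,
}


def basic_split(text: str) -> List[str]:
    """Lowercase, split on whitespace, and emit each Unicode punctuation char as its own token."""
    tokens: List[str] = []
    for word in text.lower().split():
        current = ""
        for ch in word:
            if unicodedata.category(ch).startswith("P"):
                if current:
                    tokens.append(current)
                    current = ""
                tokens.append(ch)
            else:
                current += ch
        if current:
            tokens.append(current)
    return tokens


def wordpiece(word: str, vocab: Dict[str, int]) -> List[str]:
    """Greedy longest-match-first WordPiece; the whole word becomes [UNK] if any piece is missing."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return ["[UNK]"]
        pieces.append(match)
        start = end
    return pieces


def encode(text: str, vocab: Dict[str, int], max_length: int = 128) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), padded to max_length like padding="max_length"."""
    tokens = ["[CLS]"]
    for word in basic_split(text):
        tokens.extend(wordpiece(word, vocab))
    # This sketch truncates overly long inputs so that [SEP] always fits.
    tokens = tokens[: max_length - 1] + ["[SEP]"]
    ids = [vocab[t] for t in tokens]
    mask = [1] * len(ids)
    pad = max_length - len(ids)
    return ids + [vocab["[PAD]"]] * pad, mask + [0] * pad


ids, mask = encode("The Neural Engine is really faster!", TOY_VOCAB, max_length=12)
print(ids, mask)
```

Greedy longest-match is why "faster" becomes `fast` + `##er` rather than `fa` + `##st` + ..., and a Swift port should follow the same order to produce identical ids.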