# handler.py (1.49 kB), copied from the Hugging Face Hub file view.
# Last commit: "Update handler.py" (44ee688, verified) by nbroad (HF staff).
from typing import Dict, List, Any
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
from optimum.pipelines import pipeline
import torch
# At import time, print the cuDNN version torch reports, but only when
# torch says cuDNN is available.
_cudnn = torch.backends.cudnn
if _cudnn.is_available():
    print("cudnn:", _cudnn.version())
class EndpointHandler():
    """Request handler that runs text classification through an ONNX Runtime model.

    The instance is constructed once with the model directory and is then
    called with each request payload; the call returns the pipeline's
    prediction unchanged.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path`` and build the pipeline.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        # export=False: the model at `path` is loaded as-is, not converted
        # (presumably it is already in ONNX format — confirm against the
        # repository contents). Inference is requested on the CUDA provider.
        model = ORTModelForSequenceClassification.from_pretrained(
            path,
            export=False,
            provider="CUDAExecutionProvider",
        )
        tokenizer = AutoTokenizer.from_pretrained(path)
        # create inference pipeline on device index 0
        self.pipeline = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=0,
        )

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run the classification pipeline on one request payload.

        Args:
            data: The request payload. When it is a dict, the ``"inputs"``
                entry holds the data to classify and the optional
                ``"parameters"`` entry holds keyword arguments forwarded to
                the pipeline. If ``"inputs"`` is absent, the remaining
                payload (without ``"parameters"``) is used as the input.
                Any non-dict payload is forwarded to the pipeline as the
                input itself, with no extra parameters.

        Returns:
            Whatever the pipeline returns for the given inputs. For text
            classification this is expected to be a list of dicts such as
            ``[{"label": "POSITIVE", "score": 0.9939950108528137}]``, where
            ``"label"`` names the predicted class and ``"score"`` is a
            confidence between 0 and 1 (exact nesting depends on the
            parameters passed — verify against the pipeline documentation).
        """
        if not isinstance(data, dict):
            # Raw inputs (e.g. a string or list of strings): nothing to unpack.
            return self.pipeline(data)

        # Work on a shallow copy so the caller's dict is left untouched.
        payload = dict(data)
        # A missing key and an explicit None (JSON null) both mean
        # "no extra parameters"; `**None` would raise a TypeError.
        parameters = payload.pop("parameters", None) or {}
        # Fall back to the rest of the payload when "inputs" is not given.
        inputs = payload.pop("inputs", payload)
        return self.pipeline(inputs, **parameters)