from typing import Any

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


class EndpointHandler:
    """Text-classification handler backed by a ``transformers`` pipeline.

    The tokenizer and the sequence-classification model are loaded once, at
    construction time, and every call is served by the same pipeline object.
    """

    def __init__(self, model_dir):
        """Load the tokenizer and model from *model_dir* and build the pipeline.

        Args:
            model_dir: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # NOTE(review): ignore_mismatched_sizes=True lets loading succeed even
        # if checkpoint weights do not match the configured shapes; per the
        # transformers docs such weights are then freshly initialized.
        # Confirm this is intended for an inference-only handler.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_dir,
            ignore_mismatched_sizes=True,
        )

        # Device index 0 when CUDA is available, otherwise -1 (CPU).
        self.pipeline = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )

    def __call__(self, inputs: Any) -> Any:
        """Run the classification pipeline and return its predictions.

        Args:
            inputs: Either a value the pipeline accepts directly (forwarded
                unchanged), or a request payload dict containing an
                ``"inputs"`` key and, optionally, a ``"parameters"`` dict of
                keyword arguments for the pipeline call.

        Returns:
            Whatever the underlying pipeline returns for the given input.
        """
        # Unwrap request-style payloads. Dicts without an "inputs" key (for
        # example {"text": ..., "text_pair": ...}) are forwarded untouched.
        if isinstance(inputs, dict) and "inputs" in inputs:
            parameters = inputs.get("parameters") or {}
            return self.pipeline(inputs["inputs"], **parameters)

        return self.pipeline(inputs)


def get_pipeline(model_dir):
    """Build and return an ``EndpointHandler`` for the model in *model_dir*.

    Args:
        model_dir: Forwarded unchanged to the ``EndpointHandler`` constructor.

    Returns:
        The newly constructed ``EndpointHandler`` instance.
    """
    return EndpointHandler(model_dir)