File size: 1,021 Bytes
7d810d6 948e8bf 7d810d6 c8186c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for text classification.

    Loads a tokenizer and a sequence-classification model from ``model_dir``
    and serves predictions through a ``transformers`` text-classification
    pipeline, placed on GPU 0 when CUDA is available and on CPU otherwise.
    """

    def __init__(self, model_dir):
        """Load the tokenizer, model and pipeline.

        Args:
            model_dir: Path (or identifier) handed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # NOTE(review): ignore_mismatched_sizes=True lets loading succeed when a
        # checkpoint tensor has a different shape than the model expects
        # (e.g. a classification head with another label count); transformers
        # then initializes those weights freshly instead of raising. Confirm
        # that is acceptable for a serving endpoint.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_dir,
            ignore_mismatched_sizes=True,
        )
        self.pipeline = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            # pipeline() device convention: 0 selects the first GPU, -1 is CPU.
            device=0 if torch.cuda.is_available() else -1,
        )

    def __call__(self, inputs):
        """Classify the text carried by a request.

        Args:
            inputs: Either an endpoint payload dict of the form
                ``{"inputs": <text or list of texts>, "parameters": {...}}``,
                or raw pipeline input (a string, a list of strings, or a
                ``{"text": ..., "text_pair": ...}`` dict), which is forwarded
                to the pipeline unchanged.

        Returns:
            The text-classification pipeline's predictions for the input.
        """
        payload, params = self._unpack_request(inputs)
        return self.pipeline(payload, **params)

    @staticmethod
    def _unpack_request(data):
        """Split a request into (pipeline input, pipeline keyword arguments)."""
        # Only unwrap the endpoint payload shape; a dict without "inputs"
        # (such as {"text", "text_pair"}) is a valid pipeline input itself.
        # Read with []/get rather than pop so the caller's dict is untouched.
        if isinstance(data, dict) and "inputs" in data:
            params = data.get("parameters") or {}
            return data["inputs"], params
        return data, {}
def get_pipeline(model_dir):
    """Build an ``EndpointHandler`` for the model stored in ``model_dir``.

    NOTE(review): the original comment said the Hugging Face Inference Toolkit
    calls this function; the documented custom-handler entry point is the
    ``EndpointHandler`` class itself. Confirm whether anything invokes
    ``get_pipeline`` before relying on it.

    Args:
        model_dir: Forwarded unchanged to ``EndpointHandler``.

    Returns:
        A ready-to-call ``EndpointHandler`` instance.
    """
    return EndpointHandler(model_dir)