from transformers import Pipeline | |
from sentence_transformers import SentenceTransformer | |
import torch | |
class SentimentModelPipe(Pipeline):
    """Pipeline that classifies text from its sentence-transformer embedding.

    Each input is embedded with a ``SentenceTransformer`` (``self.smodel``).
    The embedding tensor is fed to ``self.model``, and the index of the
    highest output score is returned. Judging by the class name, the index is
    presumably a sentiment label id; confirm the label mapping with the
    model's author.

    ``embedding_model`` is consumed by this class. All other arguments are
    forwarded to ``transformers.Pipeline.__init__`` (e.g. ``model``,
    ``device``).
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base pipeline, then load the sentence embedder.

        Args:
            *args: Positional arguments forwarded to ``Pipeline.__init__``
                (e.g. ``model``).
            **kwargs: Forwarded to ``Pipeline.__init__``, except
                ``embedding_model``. That option is the name or path of the
                SentenceTransformer to load (default
                ``"sentence-transformers/all-MiniLM-L6-v2"``) and is
                consumed here.
        """
        # Pop rather than get. ``embedding_model`` is this subclass's own
        # option and should not leak into the base Pipeline's kwargs.
        embedding_model = kwargs.pop(
            "embedding_model", "sentence-transformers/all-MiniLM-L6-v2"
        )
        super().__init__(*args, **kwargs)
        self.smodel = SentenceTransformer(embedding_model)

    def _sanitize_parameters(self, **kw):
        """Accept and ignore all call-time keyword options.

        Returns:
            Three empty dicts (preprocess, forward and postprocess params),
            so no stage receives extra arguments.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Embed ``inputs`` with the sentence-transformer.

        Returns:
            A torch tensor, because ``convert_to_tensor=True`` is passed. Per
            the sentence-transformers ``encode`` docs, a single string is
            expected to give a 1-D tensor.
        """
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def _forward(self, tensor):
        """Run ``self.model`` on the embedding tensor and return its raw output."""
        # Inference only, so gradients are disabled as a safeguard.
        with torch.no_grad():
            out = self.model(tensor)
        return out

    def postprocess(self, outputs):
        """Convert model scores into the predicted class index.

        The argmax is taken over the last axis. That handles both unbatched
        scores ``(num_classes,)`` and batched scores ``(batch, num_classes)``.

        Args:
            outputs: Score tensor returned by ``_forward``. Assumed to be a
                plain torch tensor with classes on the last axis; confirm
                against the model used.

        Returns:
            ``int`` when a single prediction results (including a batch of
            one, as before). Otherwise a ``list`` of ``int``, one per row.
        """
        # Use argmax(-1), not argmax(1). A 1-D score tensor (e.g. a linear
        # head fed one 1-D embedding) has no axis 1, so argmax(1) raised.
        predictions = outputs.argmax(-1)
        if predictions.numel() == 1:
            return predictions.item()
        # .item() would raise for several predictions; return them all.
        return predictions.tolist()