|
from transformers import Pipeline |
|
from sentence_transformers import SentenceTransformer |
|
import torch |
|
|
|
class SentimentModelPipe(Pipeline):
    """Pipeline that embeds text with a SentenceTransformer, then runs ``model``.

    Each input is encoded by ``self.smodel`` (a ``SentenceTransformer``) into a
    tensor, and that tensor is passed straight to ``self.model``. The model's
    raw output is returned unchanged.

    Keyword arguments are forwarded to ``Pipeline.__init__`` except
    ``embedding_model``. That one names the SentenceTransformer checkpoint used
    for encoding and defaults to ``DEFAULT_EMBEDDING_MODEL``.
    """

    # Checkpoint loaded when the caller does not pass ``embedding_model``.
    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, **kwargs):
        """Initialise the base pipeline, then load the embedding model.

        Args:
            **kwargs: Options for ``Pipeline.__init__`` (e.g. ``model``), plus
                the optional ``embedding_model`` checkpoint name consumed here.
        """
        # Pop (not get) so this class-specific option is not forwarded to
        # ``Pipeline.__init__`` as though it were a base-pipeline setting.
        embedding_model = kwargs.pop("embedding_model", self.DEFAULT_EMBEDDING_MODEL)
        super().__init__(**kwargs)
        self.smodel = SentenceTransformer(embedding_model)

    def _sanitize_parameters(self, **kw):
        """Return empty (preprocess, forward, postprocess) parameter dicts.

        This pipeline exposes no per-call options, so any keyword arguments
        received here are intentionally ignored.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Encode ``inputs`` with the SentenceTransformer as a torch tensor.

        NOTE(review): for a single string the embedding is presumably 1-D
        (no batch axis); confirm ``self.model`` accepts that shape.
        """
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def _forward(self, tensor):
        """Run ``self.model`` on the embedding tensor without tracking gradients."""
        # Explicit no_grad keeps this gradient-free even when _forward is
        # called directly rather than through the pipeline's __call__.
        with torch.no_grad():
            out = self.model(tensor)
        return out

    def postprocess(self, outputs):
        """Return the model outputs unchanged (no label mapping or softmax)."""
        return outputs