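# sentpipeline.py, as published on the Hugging Face Hub by shhossain.
# Last upstream commit: 267bbc9 ("Update sentpipeline.py", verified);
# original file size 685 bytes. (Web-view links such as "raw", "history"
# and "blame" were scraped along with the code and are omitted here.)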
from transformers import Pipeline
from sentence_transformers import SentenceTransformer
import torch
class SentimentModelPipe(Pipeline):
    """Pipeline that embeds text with a SentenceTransformer and runs
    ``self.model`` on the resulting embedding.

    The output of ``self.model`` is returned unchanged. No label mapping or
    softmax is applied here, so any sentiment semantics live in the model.

    Extra constructor keyword:
        embedding_model: name or path of the SentenceTransformer used to
            encode inputs. Defaults to ``DEFAULT_EMBEDDING_MODEL``.
    """

    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, **kwargs):
        # Pop our own option so only Pipeline's arguments are forwarded to
        # the base constructor.
        embedding_model = kwargs.pop("embedding_model", self.DEFAULT_EMBEDDING_MODEL)
        super().__init__(**kwargs)
        # Load the encoder on the pipeline's device (set by Pipeline.__init__).
        # This avoids putting it on a different device than self.model and
        # copying every embedding across before _forward.
        self.smodel = SentenceTransformer(embedding_model, device=str(self.device))

    def _sanitize_parameters(self, **kwargs):
        # No call-time options are supported. preprocess, _forward and
        # postprocess each receive an empty parameter dict, and any extra
        # keyword arguments are ignored.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Encode ``inputs`` (text accepted by ``SentenceTransformer.encode``)
        into an embedding tensor."""
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def _forward(self, tensor):
        """Run ``self.model`` on the embedding tensor produced by ``preprocess``."""
        # no_grad guarantees no autograd graph is built, even if _forward is
        # invoked directly rather than through Pipeline.forward.
        with torch.no_grad():
            return self.model(tensor)

    def postprocess(self, outputs):
        """Return the model output as-is."""
        return outputs