File size: 685 Bytes
5b64c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267bbc9
5b64c43
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import Pipeline
from sentence_transformers import SentenceTransformer
import torch

class SentimentModelPipe(Pipeline):
    """Pipeline that embeds its input with a SentenceTransformer before inference.

    The raw input is encoded into a tensor by ``self.smodel``; that tensor is
    then passed to ``self.model`` and the result is returned unchanged.

    NOTE(review): ``self.model`` is not assigned in this class; presumably it
    is set by the ``Pipeline`` base constructor from the ``model`` keyword --
    confirm against the caller.
    """

    def __init__(self, **kwargs):
        """Initialise the base pipeline, then load the embedding model.

        Args:
            **kwargs: Forwarded untouched to ``Pipeline.__init__``. The
                optional ``embedding_model`` entry names the
                SentenceTransformer to load; when absent,
                ``sentence-transformers/all-MiniLM-L6-v2`` is used.
        """
        super().__init__(**kwargs)
        embedding_model_name = kwargs.get(
            "embedding_model", "sentence-transformers/all-MiniLM-L6-v2"
        )
        self.smodel = SentenceTransformer(embedding_model_name)

    def _sanitize_parameters(self, **_ignored):
        """Return three empty parameter dicts; every keyword is ignored."""
        preprocess_params, forward_params, postprocess_params = {}, {}, {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs):
        """Encode ``inputs`` with the embedding model and return a tensor."""
        embedding = self.smodel.encode(inputs, convert_to_tensor=True)
        return embedding

    def _forward(self, tensor):
        """Call ``self.model`` on ``tensor`` with gradient tracking disabled."""
        with torch.no_grad():
            return self.model(tensor)

    def postprocess(self, outputs):
        """Return the model outputs as-is (no post-processing is applied)."""
        return outputs