File size: 993 Bytes
5b64c43
 
 
 
ce2b713
5b64c43
 
 
ce2b713
 
 
 
 
 
 
5b64c43
 
 
 
 
 
ce2b713
5b64c43
ce2b713
c64309c
 
ce2b713
 
5b64c43
 
 
ce2b713
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformers import Pipeline
from sentence_transformers import SentenceTransformer
import torch


class SentimentModelPipe(Pipeline):
    """Emotion-classification pipeline built on sentence-transformer embeddings.

    Each input is encoded with a ``SentenceTransformer`` model. The resulting
    embedding tensor is passed to ``self.model`` (the pipeline's model), and
    every element of the model output is paired with a label from
    ``class_map``.

    Extra keyword arguments read by ``__init__``:
        embedding_model: name or path given to ``SentenceTransformer``
            (default ``"sentence-transformers/all-MiniLM-L6-v2"``).
        class_map: mapping from output index to label name (default maps
            0..5 to "sad", "joy", "love", "anger", "fear", "surprise").
    """

    def __init__(self, **kwargs):
        # Every kwarg, including embedding_model / class_map, is forwarded to
        # Pipeline. _sanitize_parameters below ignores whatever it receives.
        # NOTE(review): relies on Pipeline tolerating unknown keys; confirm
        # for the pinned transformers version.
        super().__init__(**kwargs)
        self.smodel = SentenceTransformer(
            kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
        )
        # The default dict literal is rebuilt on every call, so instances
        # never share (and cannot mutate) a common default mapping.
        self.class_map = kwargs.get(
            "class_map",
            {0: "sad", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
        )

    def _sanitize_parameters(self, **kw):
        """Ignore all call-time parameters; no stage takes any options.

        Returns:
            Empty parameter dicts for preprocess, _forward and postprocess.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Encode ``inputs`` into an embedding tensor with the sentence model."""
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def postprocess(self, outputs):
        """Convert model outputs into ``[{"label": str, "score": float}, ...]``.

        ``outputs`` is iterated element-wise, and each element is converted
        with ``.item()``, so every element must hold exactly one value.
        Presumably ``outputs`` is a 1-D tensor of per-class scores (TODO
        confirm). Scores are returned exactly as the model produced them; no
        softmax or normalization is applied here.
        """
        # NOTE(review): an output index missing from class_map raises KeyError.
        return [
            {"label": self.class_map[idx], "score": score.item()}
            for idx, score in enumerate(outputs)
        ]

    def _forward(self, tensor):
        """Run the classifier on an embedding tensor without tracking gradients."""
        with torch.no_grad():
            return self.model(tensor)