File size: 993 Bytes
5b64c43 ce2b713 5b64c43 ce2b713 5b64c43 ce2b713 5b64c43 ce2b713 c64309c ce2b713 5b64c43 ce2b713 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from transformers import Pipeline
from sentence_transformers import SentenceTransformer
import torch
class SentimentModelPipe(Pipeline):
    """Sentiment/emotion classification pipeline built on sentence embeddings.

    Stages:
        * ``preprocess`` embeds the input with a ``SentenceTransformer``.
        * ``_forward`` runs ``self.model`` (the classifier head supplied to
          ``Pipeline``) on that embedding tensor.
        * ``postprocess`` pairs each model output with a label from
          ``class_map``.

    Extra keyword arguments (besides those ``Pipeline`` accepts):
        embedding_model: Name or path handed to ``SentenceTransformer``.
            Defaults to ``"sentence-transformers/all-MiniLM-L6-v2"``.
        class_map: Mapping from output index to label. Defaults to
            ``{0: "sad", 1: "joy", 2: "love", 3: "anger", 4: "fear",
            5: "surprise"}``.
    """

    def __init__(self, **kwargs):
        # Every kwarg, including the two custom ones read below, is forwarded
        # unchanged to the base Pipeline constructor.
        super().__init__(**kwargs)
        self.smodel = SentenceTransformer(
            kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
        )
        self.class_map = kwargs.get(
            "class_map",
            {0: "sad", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
        )

    def _sanitize_parameters(self, **kw):
        """Accept no per-call options; every stage receives empty kwargs.

        Per the transformers custom-pipeline convention, the three dicts are
        the kwargs for preprocess, _forward and postprocess, respectively.
        Any keyword passed here is ignored.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Encode ``inputs`` with the sentence encoder into a torch tensor."""
        return self.smodel.encode(inputs, convert_to_tensor=True)

    def postprocess(self, outputs):
        """Attach a label from ``class_map`` to each model output.

        Assumes ``outputs`` is 1-D with one score per class (TODO confirm).
        Every element must support ``.item()``. The element at position
        ``i`` is labelled ``class_map[i]``. If the model emits more values
        than ``class_map`` covers, the lookup raises (``KeyError`` for the
        default dict).

        Returns:
            A list of ``{"label": ..., "score": ...}`` dicts in output order.
            Each score is the Python number produced by ``.item()``.
        """
        return [
            {"label": self.class_map[class_idx], "score": class_score.item()}
            for class_idx, class_score in enumerate(outputs)
        ]

    def _forward(self, tensor):
        """Run the classifier model on the embedding tensor with autograd disabled."""
        with torch.no_grad():
            return self.model(tensor)
|