Update sentpipeline.py
Browse files- sentpipeline.py +20 -5
sentpipeline.py
CHANGED
@@ -2,21 +2,36 @@ from transformers import Pipeline
|
|
2 |
from sentence_transformers import SentenceTransformer
|
3 |
import torch
|
4 |
|
|
|
5 |
class SentimentModelPipe(Pipeline):
|
6 |
def __init__(self, **kwargs):
|
7 |
Pipeline.__init__(self, **kwargs)
|
8 |
-
self.smodel = SentenceTransformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def _sanitize_parameters(self, **kw):
|
11 |
return {}, {}, {}
|
12 |
|
13 |
def preprocess(self, inputs):
|
14 |
return self.smodel.encode(inputs, convert_to_tensor=True)
|
15 |
-
|
16 |
def postprocess(self, outputs):
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def _forward(self, tensor):
|
20 |
with torch.no_grad():
|
21 |
out = self.model(tensor)
|
22 |
-
return out
|
|
|
2 |
from sentence_transformers import SentenceTransformer
|
3 |
import torch
|
4 |
|
5 |
+
|
6 |
class SentimentModelPipe(Pipeline):
|
7 |
def __init__(self, **kwargs):
|
8 |
Pipeline.__init__(self, **kwargs)
|
9 |
+
self.smodel = SentenceTransformer(
|
10 |
+
kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
|
11 |
+
)
|
12 |
+
self.class_map = kwargs.get(
|
13 |
+
"class_map",
|
14 |
+
{0: "sad", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
|
15 |
+
)
|
16 |
|
17 |
def _sanitize_parameters(self, **kw):
|
18 |
return {}, {}, {}
|
19 |
|
20 |
def preprocess(self, inputs):
|
21 |
return self.smodel.encode(inputs, convert_to_tensor=True)
|
22 |
+
|
23 |
def postprocess(self, outputs):
|
24 |
+
if isinstance(outputs, torch.Tensor):
|
25 |
+
outputs = [outputs]
|
26 |
+
results = []
|
27 |
+
for out in outputs:
|
28 |
+
r = []
|
29 |
+
for i, l in enumerate(out):
|
30 |
+
r.append({"label": self.class_map[i], "score": l.item()})
|
31 |
+
results.append(r)
|
32 |
+
return results
|
33 |
+
|
34 |
def _forward(self, tensor):
|
35 |
with torch.no_grad():
|
36 |
out = self.model(tensor)
|
37 |
+
return out
|