shhossain commited on
Commit
ce2b713
1 Parent(s): 267bbc9

Update sentpipeline.py

Browse files
Files changed (1) hide show
  1. sentpipeline.py +20 -5
sentpipeline.py CHANGED
@@ -2,21 +2,36 @@ from transformers import Pipeline
2
  from sentence_transformers import SentenceTransformer
3
  import torch
4
 
 
5
  class SentimentModelPipe(Pipeline):
6
  def __init__(self, **kwargs):
7
  Pipeline.__init__(self, **kwargs)
8
- self.smodel = SentenceTransformer(kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2"))
 
 
 
 
 
 
9
 
10
  def _sanitize_parameters(self, **kw):
11
  return {}, {}, {}
12
 
13
  def preprocess(self, inputs):
14
  return self.smodel.encode(inputs, convert_to_tensor=True)
15
-
16
  def postprocess(self, outputs):
17
- return outputs
18
-
 
 
 
 
 
 
 
 
19
  def _forward(self, tensor):
20
  with torch.no_grad():
21
  out = self.model(tensor)
22
- return out
 
2
  from sentence_transformers import SentenceTransformer
3
  import torch
4
 
5
+
6
  class SentimentModelPipe(Pipeline):
7
  def __init__(self, **kwargs):
8
  Pipeline.__init__(self, **kwargs)
9
+ self.smodel = SentenceTransformer(
10
+ kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
11
+ )
12
+ self.class_map = kwargs.get(
13
+ "class_map",
14
+ {0: "sad", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
15
+ )
16
 
17
  def _sanitize_parameters(self, **kw):
18
  return {}, {}, {}
19
 
20
  def preprocess(self, inputs):
21
  return self.smodel.encode(inputs, convert_to_tensor=True)
22
+
23
  def postprocess(self, outputs):
24
+ if isinstance(outputs, torch.Tensor):
25
+ outputs = [outputs]
26
+ results = []
27
+ for out in outputs:
28
+ r = []
29
+ for i, l in enumerate(out):
30
+ r.append({"label": self.class_map[i], "score": l.item()})
31
+ results.append(r)
32
+ return results
33
+
34
  def _forward(self, tensor):
35
  with torch.no_grad():
36
  out = self.model(tensor)
37
+ return out