osanseviero commited on
Commit
bab71b5
1 Parent(s): 9ae440b

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +10 -16
pipeline.py CHANGED
@@ -7,30 +7,24 @@ import fasttext.util
7
  class PreTrainedPipeline():
8
 
9
  def __init__(self, path=""):
10
-
11
  """
12
-
13
  Initialize model
14
-
15
  """
16
 
17
  self.model = fasttext.load_model(os.path.join(path, 'debate2vec.bin'))
18
 
19
- def __call__(self, inputs: str) -> List[float]:
20
-
21
  """
22
-
23
  Args:
24
-
25
  inputs (:obj:`str`):
26
-
27
- a string to get the features of.
28
-
29
  Return:
30
-
31
- A :obj:`list` of floats: The features computed by the model.
32
-
33
  """
34
-
35
- return self.model.get_sentence_vector(inputs).tolist()
36
-
 
 
 
7
  class PreTrainedPipeline():
8
 
9
  def __init__(self, path=""):
 
10
  """
 
11
  Initialize model
 
12
  """
13
 
14
  self.model = fasttext.load_model(os.path.join(path, 'debate2vec.bin'))
15
 
16
+ def __call__(self, inputs: str) -> List[List[Dict[str, float]]]:
 
17
  """
 
18
  Args:
 
19
  inputs (:obj:`str`):
20
+ a string containing some text
 
 
21
  Return:
22
+ A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
23
+ - "label": A string representing what the label/class is. There can be multiple labels.
24
+ - "score": A score between 0 and 1 describing how confident the model is for this label/class.
25
  """
26
+ preds = self.model.get_nearest_neighbors("dog", k=10)
27
+ result = []
28
+ for distance, word in preds:
29
+ result.append({"label": word, "score": distance})
30
+ return [result]