Pierce Maloney commited on
Commit
216cf30
1 Parent(s): 60ed7ab

back to normal, removed additional bad words argument

Browse files
Files changed (1) hide show
  1. handler.py +4 -9
handler.py CHANGED
@@ -11,7 +11,7 @@ class EndpointHandler():
11
  self.tokenizer = tokenizer
12
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
13
 
14
- def __call__(self, data: Dict[str, Any], additional_bad_words_ids: List[List[int]] = None) -> List[Dict[str, Any]]:
15
  """
16
  data args:
17
  inputs (:obj: `str`)
@@ -23,9 +23,7 @@ class EndpointHandler():
23
 
24
 
25
  # Bad word: id 3070, 10456 corresponds to "(*", and we do not want to output a comment
26
- bad_words_ids = [[3070], [313, 334], [10456], [29871, 25956, 413, 325, 301, 29889]]
27
- if additional_bad_words_ids:
28
- bad_words_ids.extend(additional_bad_words_ids)
29
 
30
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
31
 
@@ -39,11 +37,8 @@ class EndpointHandler():
39
  stopping_criteria=self.stopping_criteria,
40
  )
41
 
42
- # Slice the generated_ids to only include the new tokens generated, excluding the input tokens
43
- generated_ids = generated_ids[:, input_ids.shape[1]:]
44
-
45
- generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
46
- prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0].tolist()}]
47
  return prediction
48
 
49
 
 
11
  self.tokenizer = tokenizer
12
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
13
 
14
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
  """
16
  data args:
17
  inputs (:obj: `str`)
 
23
 
24
 
25
  # Bad word: id 3070, 10456 corresponds to "(*", and we do not want to output a comment
26
+ bad_words_ids = [[3070], [313, 334], [10456]]
 
 
27
 
28
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
29
 
 
37
  stopping_criteria=self.stopping_criteria,
38
  )
39
 
40
+ generated_text = self.tokenizer.decode(generated_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
41
+ prediction = [{"generated_text": generated_text}]
 
 
 
42
  return prediction
43
 
44