Pierce Maloney committed
Commit 833b301
Parent: 88fdb99

adding new bad word

Files changed (2)
  1. handler.py +9 -15
  2. test_tokenizer +0 -0
handler.py CHANGED
@@ -7,9 +7,9 @@ class EndpointHandler():
     def __init__(self, path=""):
         # Preload all the elements you are going to need at inference.
         tokenizer = AutoTokenizer.from_pretrained(path)
-        self.tokenizer = tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(path)
-        self.tokenizer.pad_token = tokenizer.eos_token
+        model = AutoModelForCausalLM.from_pretrained(path)
+        tokenizer.pad_token = tokenizer.eos_token
+        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
         self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -21,24 +21,18 @@ class EndpointHandler():
         A :obj:`list` | `dict`: will be serialized and returned
         """
         inputs = data.pop("inputs", data)
-        input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
 
         # Bad word: id 3070 corresponds to "(*", and we do not want to output a comment
-        prediction_ids = self.model.generate(
-            input_ids,
-            max_length=input_ids.shape[1] + 50,
+        prediction = self.pipeline(
+            inputs,
             stopping_criteria=self.stopping_criteria,
-            bad_words_ids=[[3070], [313, 334]],
+            max_new_tokens=50,
+            return_full_text=False,
+            bad_words_ids=[[3070], [313, 334], [10456]],
             temperature=1,
             top_k=40,
-            # pad_token_id=self.tokenizer.eos_token_id,
-            # return_dict_in_generate=True, # To get more detailed output (optional)
         )
-
-        # Decode the generated ids to text
-        # Exclude the input_ids length to get only the new tokens
-        prediction_text = self.tokenizer.decode(prediction_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
-        return [{"generated_text": prediction_text, "ids": prediction_ids[0, input_ids.shape[1]:].tolist()}]
+        return prediction
 
 
 class StopAtPeriodCriteria(StoppingCriteria):
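
The diff cuts off at the declaration of StopAtPeriodCriteria, whose body is unchanged by this commit and therefore not shown. For context, a minimal sketch of what such a StoppingCriteria subclass typically looks like; the decode-and-check logic below is an illustrative assumption, not the repo's actual implementation:

    import torch
    from transformers import StoppingCriteria

    class StopAtPeriodCriteria(StoppingCriteria):
        """Stop generation once the decoded output ends with a period."""

        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Decode only the tail of the sequence: a "." can be fused into a
            # larger token, so checking the last token id alone could miss it.
            tail = self.tokenizer.decode(input_ids[0, -3:], skip_special_tokens=True)
            return tail.rstrip().endswith(".")
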
test_tokenizer ADDED
(empty file)
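
A note on the commit itself: bad_words_ids takes a list of token-id sequences, and this commit adds a third banned sequence, [10456], alongside [3070] (which the handler's comment says decodes to "(*") and [313, 334]. To ban another string, the usual approach is to encode it without special tokens and pass the resulting ids; a sketch, with a placeholder model path:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/model")  # placeholder path

    # Each banned string becomes one token-id sequence that generate()/pipeline()
    # will refuse to emit. The ids depend entirely on the tokenizer's vocabulary.
    banned_strings = ["(*"]  # example string from the handler's comment
    bad_words_ids = tokenizer(banned_strings, add_special_tokens=False).input_ids
    print(bad_words_ids)

Since the handler now wraps a text-generation pipeline with return_full_text=False, a local smoke test would call it with a payload in the Inference Endpoints shape, e.g. EndpointHandler(path="path/to/model")({"inputs": "..."}), and get back only the newly generated continuation as [{"generated_text": "..."}].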