Pierce Maloney commited on
Commit
000ad8b
1 Parent(s): 833b301

using .generate, returning ids, custom bad_words

Browse files
Files changed (1) hide show
  1. handler.py +23 -11
handler.py CHANGED
@@ -2,17 +2,16 @@ from typing import Dict, List, Any
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
3
 
4
 
5
-
6
  class EndpointHandler():
7
  def __init__(self, path=""):
8
  # Preload all the elements you are going to need at inference.
9
  tokenizer = AutoTokenizer.from_pretrained(path)
10
- model = AutoModelForCausalLM.from_pretrained(path)
11
  tokenizer.pad_token = tokenizer.eos_token
12
- self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
 
13
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
14
 
15
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
17
  data args:
18
  inputs (:obj: `str`)
@@ -22,16 +21,29 @@ class EndpointHandler():
22
  """
23
  inputs = data.pop("inputs", data)
24
 
25
- # Bad word: id 3070 corresponds to "(*", and we do not want to output a comment
26
- prediction = self.pipeline(
27
- inputs,
28
- stopping_criteria=self.stopping_criteria,
29
- max_new_tokens=50,
30
- return_full_text=False,
31
- bad_words_ids=[[3070], [313, 334], [10456]],
 
 
 
 
 
 
32
  temperature=1,
33
  top_k=40,
 
34
  )
 
 
 
 
 
 
35
  return prediction
36
 
37
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
3
 
4
 
 
5
  class EndpointHandler():
6
  def __init__(self, path=""):
7
  # Preload all the elements you are going to need at inference.
8
  tokenizer = AutoTokenizer.from_pretrained(path)
 
9
  tokenizer.pad_token = tokenizer.eos_token
10
+ self.model = AutoModelForCausalLM.from_pretrained(path)
11
+ self.tokenizer = tokenizer
12
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
13
 
14
+ def __call__(self, data: Dict[str, Any], additional_bad_words_ids: List[List[int]] = None) -> List[Dict[str, Any]]:
15
  """
16
  data args:
17
  inputs (:obj: `str`)
 
21
  """
22
  inputs = data.pop("inputs", data)
23
 
24
+
25
+ # Bad word: id 3070, 10456 corresponds to "(*", and we do not want to output a comment
26
+ bad_words_ids = [[3070], [313, 334], [10456]]
27
+ if additional_bad_words_ids:
28
+ bad_words_ids.extend(additional_bad_words_ids)
29
+
30
+ input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
31
+
32
+ # Generate text using model.generate
33
+ generated_ids = self.model.generate(
34
+ input_ids,
35
+ max_length=input_ids.shape[1] + 50, # 50 new tokens
36
+ bad_words_ids=bad_words_ids,
37
  temperature=1,
38
  top_k=40,
39
+ stopping_criteria=self.stopping_criteria,
40
  )
41
+
42
+ # Slice the generated_ids to only include the new tokens generated, excluding the input tokens
43
+ generated_ids = generated_ids[:, input_ids.shape[1]:]
44
+
45
+ generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
46
+ prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0].tolist()}]
47
  return prediction
48
 
49