Pierce Maloney committed
Commit dee492f
1 Parent(s): 813fd4a

back to .generate. forgot about returning gen ids

Files changed (1)
  1. handler.py +34 -73
handler.py CHANGED
@@ -2,95 +2,56 @@ from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
 
 
-# class EndpointHandler():
-#     def __init__(self, path=""):
-#         tokenizer = AutoTokenizer.from_pretrained(path)
-#         tokenizer.pad_token = tokenizer.eos_token
-#         self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
-#         self.tokenizer = tokenizer
-#         self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
-
-#     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-#         """
-#         data args:
-#             inputs (:obj: `str`)
-#             kwargs
-#         Return:
-#             A :obj:`list` | `dict`: will be serialized and returned
-#         """
-#         inputs = data.pop("inputs", data)
-#         additional_bad_words_ids = data.pop("additional_bad_words_ids", [])
-
-
-#         # 3070, 10456, [313, 334] corresponds to "(*", and we do not want to output a comment
-#         # 13 is a newline character
-#         # [1976, 441, 29889], [4920, 441, 29889] is "Abort." [4920, 18054, 29889] is "Aborted."
-#         # [2087, 29885, 4430, 29889] is "Admitted."
-#         bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
-#         bad_words_ids.extend(additional_bad_words_ids)
-
-#         input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
-#         max_generation_length = 75  # Desired number of tokens to generate
-#         # max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation
-
-#         # # Truncate input_ids to the most recent tokens that fit within the max_input_length
-#         # if input_ids.shape[1] > max_input_length:
-#         #     input_ids = input_ids[:, -max_input_length:]
-
-#         max_length = input_ids.shape[1] + max_generation_length
-
-#         generated_ids = self.model.generate(
-#             input_ids,
-#             max_length=max_length,  # 50 new tokens
-#             bad_words_ids=bad_words_ids,
-#             temperature=1,
-#             top_k=40,
-#             do_sample=True,
-#             stopping_criteria=self.stopping_criteria,
-#         )
-
-#         generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
-#         prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
-#         return prediction
-
 class EndpointHandler():
     def __init__(self, path=""):
-        self.model_path = path
         tokenizer = AutoTokenizer.from_pretrained(path)
         tokenizer.pad_token = tokenizer.eos_token
+        self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
         self.tokenizer = tokenizer
-        # Initialize the pipeline for text generation
-        self.text_generation_pipeline = pipeline("text-generation", model=path, tokenizer=self.tokenizer, device=0)  # device=0 for CUDA
+        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         data args:
            inputs (:obj: `str`)
            kwargs
         Return:
            A :obj:`list` | `dict`: will be serialized and returned
         """
         inputs = data.pop("inputs", data)
         additional_bad_words_ids = data.pop("additional_bad_words_ids", [])
 
-        # Define bad words to avoid in the output
+
+        # 3070, 10456, [313, 334] corresponds to "(*", and we do not want to output a comment
+        # 13 is a newline character
+        # [1976, 441, 29889], [4920, 441, 29889] is "Abort." [4920, 18054, 29889] is "Aborted."
+        # [2087, 29885, 4430, 29889] is "Admitted."
         bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
         bad_words_ids.extend(additional_bad_words_ids)
 
-        # Generate text using the pipeline
-        generation_kwargs = {
-            "max_new_tokens": 75,
-            "temperature": 0.7,
-            "top_k": 40,
-            "bad_words_ids": bad_words_ids,
-            "pad_token_id": self.tokenizer.eos_token_id,
-            "return_full_text": False,  # Only return the new generated tokens
-        }
-        generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)
+        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
+        max_generation_length = 75  # Desired number of tokens to generate
+        # max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation
+
+        # # Truncate input_ids to the most recent tokens that fit within the max_input_length
+        # if input_ids.shape[1] > max_input_length:
+        #     input_ids = input_ids[:, -max_input_length:]
+
+        max_length = input_ids.shape[1] + max_generation_length
 
-        # Format the output
-        predictions = [{"generated_text": output["generated_text"]} for output in generated_outputs]
-        return predictions
+        generated_ids = self.model.generate(
+            input_ids,
+            max_length=max_length,  # input length + 75 new tokens
+            bad_words_ids=bad_words_ids,
+            temperature=1,
+            top_k=40,
+            do_sample=True,
+            stopping_criteria=self.stopping_criteria,
+        )
+
+        generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
+        prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
+        return prediction
 
 
 class StopAtPeriodCriteria(StoppingCriteria):
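
Note: the hunk ends at the StopAtPeriodCriteria header, so the class body is not part of this diff. For orientation, a minimal sketch of such a criterion, assuming it halts generation once the newest token decodes to text containing a period; the actual body in handler.py may differ:

import torch
from transformers import StoppingCriteria

class StopAtPeriodCriteria(StoppingCriteria):
    # Hypothetical sketch; the real implementation lives below the hunk shown above.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the most recently generated token; stop as soon as it contains a period.
        last_token_text = self.tokenizer.decode(input_ids[0, -1])
        return "." in last_token_text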
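The inline comments map the hard-coded bad_words_ids back to strings such as "(*", "Abort.", "Aborted.", and "Admitted.". If those lists ever need to be regenerated, they can be derived from the tokenizer rather than hard-coded; a sketch, with the model path as a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/model")  # placeholder path

# Encode each banned string without special tokens to recover its id sequence.
banned = ["(*", "Abort.", "Aborted.", "Admitted."]
bad_words_ids = [tokenizer.encode(s, add_special_tokens=False) for s in banned]

The ids are tokenizer-specific, and a string can tokenize differently depending on surrounding context, which may be one reason the handler also accepts additional_bad_words_ids per request.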
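The commit message sums up the motivation for the revert: the pipeline call only yields decoded text, while model.generate returns the full tensor of token ids, which __call__ slices past the prompt and returns alongside the text. A usage sketch under that assumption, with a placeholder path and prompt:

handler = EndpointHandler(path="path/to/model")  # placeholder path

result = handler({"inputs": "Lemma example :"})  # placeholder prompt
# After this commit each prediction carries both fields:
# [{"generated_text": "...", "generated_ids": [...]}]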