Pierce Maloney committed
Commit a36be93
1 Parent(s): b873ed7

pipeline trial

Files changed (1):
  handler.py +75 -37
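
This commit swaps the handler's manual model.generate() call for a transformers text-generation pipeline. As a quick sanity check of the new handler, here is a minimal local smoke-test sketch; the model directory and the Coq-style prompt are hypothetical placeholders, and a CUDA device is assumed since the pipeline is created with device=0:

# Hypothetical smoke test for the pipeline-based handler (not part of the commit).
from handler import EndpointHandler

MODEL_DIR = "./model"  # placeholder: local copy of the model repo this handler serves

handler = EndpointHandler(path=MODEL_DIR)
payload = {
    "inputs": "Lemma add_comm : forall n m : nat, n + m = m + n.",  # placeholder prompt
    "additional_bad_words_ids": [],  # optional extra token-id sequences to ban
}
predictions = handler(payload)
print(predictions[0]["generated_text"])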
handler.py CHANGED
@@ -1,59 +1,97 @@
  from typing import Dict, List, Any
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
- import torch


  class EndpointHandler():
      def __init__(self, path=""):
          tokenizer = AutoTokenizer.from_pretrained(path)
-         tokenizer.pad_token = tokenizer.eos_token
-         self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
          self.tokenizer = tokenizer
-         self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])

      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
          """
-         data args:
-             inputs (:obj: `str`)
-             kwargs
-         Return:
-             A :obj:`list` | `dict`: will be serialized and returned
          """
-         torch.cuda.empty_cache()
          inputs = data.pop("inputs", data)
          additional_bad_words_ids = data.pop("additional_bad_words_ids", [])

-
-         # 3070, 10456, and [313, 334] correspond to "(*", and we do not want to output a comment
-         # 13 is a newline character
-         # [1976, 441, 29889] and [4920, 441, 29889] are "Abort."; [4920, 18054, 29889] is "Aborted."
-         # [2087, 29885, 4430, 29889] is "Admitted."
          bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
          bad_words_ids.extend(additional_bad_words_ids)

-         input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
-         max_generation_length = 75  # Desired number of tokens to generate
-         # max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation
-
-         # # Truncate input_ids to the most recent tokens that fit within the max_input_length
-         # if input_ids.shape[1] > max_input_length:
-         #     input_ids = input_ids[:, -max_input_length:]

-         max_length = input_ids.shape[1] + max_generation_length
-
-         generated_ids = self.model.generate(
-             input_ids,
-             max_length=max_length,  # 75 new tokens
-             bad_words_ids=bad_words_ids,
-             temperature=1,
-             top_k=40,
-             do_sample=True,
-             stopping_criteria=self.stopping_criteria,
-         )
-
-         generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
-         prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
-         return prediction


  class StopAtPeriodCriteria(StoppingCriteria):
 
  from typing import Dict, List, Any
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList


+ # class EndpointHandler():
+ #     def __init__(self, path=""):
+ #         tokenizer = AutoTokenizer.from_pretrained(path)
+ #         tokenizer.pad_token = tokenizer.eos_token
+ #         self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
+ #         self.tokenizer = tokenizer
+ #         self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
+
+ #     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+ #         """
+ #         data args:
+ #             inputs (:obj: `str`)
+ #             kwargs
+ #         Return:
+ #             A :obj:`list` | `dict`: will be serialized and returned
+ #         """
+ #         inputs = data.pop("inputs", data)
+ #         additional_bad_words_ids = data.pop("additional_bad_words_ids", [])
+
+ #         # 3070, 10456, and [313, 334] correspond to "(*", and we do not want to output a comment
+ #         # 13 is a newline character
+ #         # [1976, 441, 29889] and [4920, 441, 29889] are "Abort."; [4920, 18054, 29889] is "Aborted."
+ #         # [2087, 29885, 4430, 29889] is "Admitted."
+ #         bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
+ #         bad_words_ids.extend(additional_bad_words_ids)
+
+ #         input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
+ #         max_generation_length = 75  # Desired number of tokens to generate
+ #         # max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation
+
+ #         # # Truncate input_ids to the most recent tokens that fit within the max_input_length
+ #         # if input_ids.shape[1] > max_input_length:
+ #         #     input_ids = input_ids[:, -max_input_length:]
+
+ #         max_length = input_ids.shape[1] + max_generation_length
+
+ #         generated_ids = self.model.generate(
+ #             input_ids,
+ #             max_length=max_length,  # 75 new tokens
+ #             bad_words_ids=bad_words_ids,
+ #             temperature=1,
+ #             top_k=40,
+ #             do_sample=True,
+ #             stopping_criteria=self.stopping_criteria,
+ #         )
+
+ #         generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
+ #         prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
+ #         return prediction
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
  class EndpointHandler():
      def __init__(self, path=""):
+         self.model_path = path
          tokenizer = AutoTokenizer.from_pretrained(path)
+         tokenizer.pad_token = tokenizer.eos_token  # Use EOS as the padding token
          self.tokenizer = tokenizer
+         # Initialize the pipeline for text generation
+         self.text_generation_pipeline = pipeline("text-generation", model=path, tokenizer=self.tokenizer, device=0)  # device=0 for CUDA

      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
          """
+         data args:
+             inputs (:obj: `str`)
+             kwargs
+         Return:
+             A :obj:`list` | `dict`: will be serialized and returned
          """
          inputs = data.pop("inputs", data)
          additional_bad_words_ids = data.pop("additional_bad_words_ids", [])

+         # Define bad words to avoid in the output
          bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
          bad_words_ids.extend(additional_bad_words_ids)

+         # Generate text using the pipeline
+         generation_kwargs = {
+             "max_length": 75,  # Adjust as needed; caps prompt plus generated tokens, unlike the old 75 new tokens
+             "temperature": 1,  # temperature/top_k only take effect when do_sample=True
+             "top_k": 40,
+             "bad_words_ids": bad_words_ids,
+             "pad_token_id": self.tokenizer.eos_token_id  # Ensure padding with EOS token
+         }
+         generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)

+         # Format the output
+         predictions = [{"generated_text": output["generated_text"]} for output in generated_outputs]
+         return predictions


  class StopAtPeriodCriteria(StoppingCriteria):
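
For reference, the hard-coded bad_words_ids above are tokenizer-specific token-id sequences; the comments in the diff map them to "(*", the newline character, "Abort.", "Aborted.", and "Admitted.". A sketch of how such sequences can be derived, assuming the handler's own tokenizer (the model path is a hypothetical placeholder):

# Hypothetical helper for deriving bad_words_ids (not part of the commit).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./model")  # placeholder path

# Strings the handler should never emit, per the comments in the diff.
banned = ["(*", "\n", "Abort.", "Aborted.", "Admitted."]

# Each bad_words_ids entry is the token-id sequence for one banned string.
bad_words_ids = [tokenizer.encode(s, add_special_tokens=False) for s in banned]
print(bad_words_ids)

Note that subword tokenizers can split the same surface string differently depending on leading whitespace, which is presumably why the diff bans several id variants of "(*" and "Abort.".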