Pierce Maloney committed on
Commit
02ffbef
1 Parent(s): 37b54cb

testing custom logit processor for bad words

Browse files
Files changed (3) hide show
  1. __pycache__/handler.cpython-311.pyc +0 -0
  2. handler.py +21 -4
  3. sample.py +1 -1
__pycache__/handler.cpython-311.pyc CHANGED
Binary files a/__pycache__/handler.cpython-311.pyc and b/__pycache__/handler.cpython-311.pyc differ
 
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Dict, List, Any
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
3
 
4
 
5
 
@@ -11,6 +11,7 @@ class EndpointHandler():
11
  tokenizer.pad_token = tokenizer.eos_token
12
  self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
13
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
 
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
@@ -28,8 +29,9 @@ class EndpointHandler():
28
  stopping_criteria=self.stopping_criteria,
29
  max_new_tokens=50,
30
  return_full_text=False,
31
- bad_words_ids=[[3070], [313, 334]],
32
- temperature=2,
 
33
  top_k=40,
34
  )
35
  return prediction
@@ -43,4 +45,19 @@ class StopAtPeriodCriteria(StoppingCriteria):
43
  # Decode the last generated token to text
44
  last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
45
  # Check if the decoded text ends with a period
46
- return '.' in last_token_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList
3
 
4
 
5
 
 
11
  tokenizer.pad_token = tokenizer.eos_token
12
  self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
13
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
14
+ self.logits_processor = LogitsProcessorList([BanSpecificTokensLogitsProcessor(tokenizer, [3070])])
15
 
16
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
  """
 
29
  stopping_criteria=self.stopping_criteria,
30
  max_new_tokens=50,
31
  return_full_text=False,
32
+ # bad_words_ids=[[3070], [313, 334]],
33
+ logits_processor=self.logits_processor,
34
+ temperature=1,
35
  top_k=40,
36
  )
37
  return prediction
 
45
  # Decode the last generated token to text
46
  last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
47
  # Check if the decoded text ends with a period
48
+ return '.' in last_token_text
49
+
50
class BanSpecificTokensLogitsProcessor(LogitsProcessor):
    """
    Logits processor that forbids specific token ids from being generated
    by forcing their logits to negative infinity before sampling.
    """

    def __init__(self, tokenizer, banned_tokens_ids):
        # The tokenizer is stored for interface parity; banning itself
        # operates purely on token ids.
        self.tokenizer = tokenizer
        self.banned_tokens_ids = banned_tokens_ids

    def __call__(self, input_ids, scores):
        # Zero out (to -inf) the logit of every banned token across the
        # whole batch, then return the modified scores to the pipeline.
        neg_inf = float('-inf')
        for banned_id in self.banned_tokens_ids:
            scores[:, banned_id] = neg_inf
        return scores
sample.py CHANGED
@@ -4,7 +4,7 @@ from handler import EndpointHandler
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
- payload = {"inputs": "I am so happy to reach in my pocket and find a"}
8
 
9
  # test the handler
10
  payload=my_handler(payload)
 
4
  my_handler = EndpointHandler(path=".")
5
 
6
  # prepare sample payload
7
+ payload = {"inputs": "This is the format for a"}
8
 
9
  # test the handler
10
  payload=my_handler(payload)