hardmaru dmayhem93 commited on
Commit
072102d
1 Parent(s): d4332bd

Add stopping criteria to readme example (#2)

Browse files

- Add stopping criteria to readme example (0b0b9ddd9b93fbc1250c64480607cbc819ae5e38)


Co-authored-by: dmayhem93 <[email protected]>

Files changed (1) hide show
  1. README.md +10 -1
README.md CHANGED
@@ -31,6 +31,14 @@ tokenizer = AutoTokenizer.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
31
  model = AutoModelForCausalLM.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
32
  model.half().cuda()
33
 
 
 
 
 
 
 
 
 
34
  system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
35
  - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
36
  - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
@@ -38,7 +46,7 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
38
  - StableLM will refuse to participate in anything that could harm a human.
39
  """
40
 
41
- prompt = f"{system_prompt}<|USER|>What's your mood today?"
42
 
43
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
44
  tokens = model.generate(
@@ -46,6 +54,7 @@ tokens = model.generate(
46
  max_new_tokens=64,
47
  temperature=0.7,
48
  do_sample=True,
 
49
  )
50
  print(tokenizer.decode(tokens[0], skip_special_tokens=True))
51
  ```
 
31
  model = AutoModelForCausalLM.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
32
  model.half().cuda()
33
 
34
+ class StopOnTokens(StoppingCriteria):
35
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
36
+ stop_ids = [50278, 50279, 50277, 1, 0]
37
+ for stop_id in stop_ids:
38
+ if input_ids[0][-1] == stop_id:
39
+ return True
40
+ return False
41
+
42
  system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
43
  - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
44
  - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
 
46
  - StableLM will refuse to participate in anything that could harm a human.
47
  """
48
 
49
+ prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
50
 
51
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
52
  tokens = model.generate(
 
54
  max_new_tokens=64,
55
  temperature=0.7,
56
  do_sample=True,
57
+ stopping_criteria=StoppingCriteriaList([StopOnTokens()])
58
  )
59
  print(tokenizer.decode(tokens[0], skip_special_tokens=True))
60
  ```