Add StoppingCriteria to example to stop on one assistant response

#2
by dmayhem93 - opened
Files changed (1) hide show
  1. README.md +14 -2
README.md CHANGED
@@ -24,12 +24,20 @@ datasets:
24
  Get started chatting with `StableLM-Tuned-Alpha` by using the following code snippet:
25
 
26
  ```python
27
- from transformers import AutoModelForCausalLM, AutoTokenizer
28
 
29
  tokenizer = AutoTokenizer.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
30
  model = AutoModelForCausalLM.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
31
  model.half().cuda()
32
 
 
 
 
 
 
 
 
 
33
  system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
34
  - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
35
  - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
@@ -37,12 +45,16 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
37
  - StableLM will refuse to participate in anything that could harm a human.
38
  """
39
 
40
- prompt = f"{system_prompt}<|USER|>What's your mood today?"
41
 
42
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
43
  tokens = model.generate(
44
  **inputs,
45
  max_new_tokens=64,
 
 
 
 
46
  print(tokenizer.decode(tokens[0], skip_special_tokens=True))
47
  ```
48
 
 
24
  Get started chatting with `StableLM-Tuned-Alpha` by using the following code snippet:
25
 
26
  ```python
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
28
 
29
  tokenizer = AutoTokenizer.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
30
  model = AutoModelForCausalLM.from_pretrained("StabilityAI/stablelm-tuned-alpha-7b")
31
  model.half().cuda()
32
 
33
+ class StopOnTokens(StoppingCriteria):
34
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
35
+ stop_ids = [50278, 50279, 50277, 1, 0]
36
+ for stop_id in stop_ids:
37
+ if input_ids[0][-1] == stop_id:
38
+ return True
39
+ return False
40
+
41
  system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
42
  - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
43
  - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
 
45
  - StableLM will refuse to participate in anything that could harm a human.
46
  """
47
 
48
+ prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
49
 
50
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
51
  tokens = model.generate(
52
  **inputs,
53
  max_new_tokens=64,
54
+ temperature=0.7,
55
+ do_sample=True,
56
+ stopping_criteria=StoppingCriteriaList([StopOnTokens()])
57
+ )
58
  print(tokenizer.decode(tokens[0], skip_special_tokens=True))
59
  ```
60