jacobrenn committed
Commit c90ac53
1 Parent(s): 761fbc8

Update README.md

Files changed (1)
  1. README.md +29 -28
README.md CHANGED
@@ -92,33 +92,34 @@ def create_response(
     """
     Create a response from the model by using a formatted prompt
     """
-    ids = tokenizer(PROMPT_FORMAT.format(instruction = instruction), return_tensors = 'pt').input_ids
-
-    response_id = tokenizer.encode(RESPONSE_KEY)[0]
-    end_id = tokenizer.encode(END_KEY)[0]
-
-    tokens = model.generate(
-        ids,
-        pad_token_id = tokenizer.pad_token_id,
-        eos_token_id = end_id,
-        do_sample = do_sample,
-        max_new_tokens = max_new_tokens,
-        top_p = top_p,
-        top_k = top_k,
-        **kwargs
-    )[0].cpu()
-
-    res_pos = np.where(tokens == response_id)[0]
-
-    if len(res_pos) == 0:
-        return None
-
-    res_pos = res_pos[0]
-    end_pos = np.where(tokens == end_id)[0]
-    if len(end_pos) > 0:
-        end_pos = end_pos[0]
+    input_ids = tokenizer(
+        PROMPT.format(instruction=instruction), return_tensors="pt"
+    ).input_ids
+
+    gen_tokens = model.generate(
+        input_ids,
+        pad_token_id=tokenizer.pad_token_id,
+        do_sample=do_sample,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        top_k=top_k,
+        **kwargs,
+    )
+    decoded = tokenizer.batch_decode(gen_tokens)[0]
+
+    # The response appears after "### Response:". The model has been trained to append "### End" at the end.
+    m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", decoded, flags=re.DOTALL)
+
+    response = None
+    if m:
+        response = m.group(1).strip()
     else:
-        end_pos = None
-
-    return tokenizer.decode(tokens[res_pos + 1 : end_pos]).strip()
+        # The model might not generate the "### End" sequence before reaching the max tokens. In this case, return
+        # everything after "### Response:".
+        m = re.search(r"#+\s*Response:\s*(.+)", decoded, flags=re.DOTALL)
+        if m:
+            response = m.group(1).strip()
+        else:
+            pass
+    return response
     ```
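
The added lines switch from locating special token ids to extracting the answer from the decoded text with regular expressions. The sketch below is a minimal, standalone illustration of that extraction step only; the helper name `extract_response` and the sample strings are illustrative and are not part of the commit.

```python
import re


def extract_response(decoded: str):
    # Mirrors the updated README logic: take the text between "### Response:"
    # and "### End" when the end marker was generated...
    m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", decoded, flags=re.DOTALL)
    if m:
        return m.group(1).strip()
    # ...otherwise fall back to everything after "### Response:".
    m = re.search(r"#+\s*Response:\s*(.+)", decoded, flags=re.DOTALL)
    return m.group(1).strip() if m else None


print(extract_response("### Instruction:\nSay hi.\n### Response:\nHello!\n### End"))
# -> "Hello!"
print(extract_response("### Instruction:\nSay hi.\n### Response:\nHello there"))
# -> "Hello there" (the model hit max_new_tokens before emitting "### End")
print(extract_response("no markers here"))
# -> None
```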