pseudotensor
commited on
Commit
•
a6d8676
1
Parent(s):
d9fa842
Upload h2oai_pipeline.py
Browse files- h2oai_pipeline.py +344 -162
h2oai_pipeline.py
CHANGED
@@ -71,8 +71,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
71 |
# unknown
|
72 |
model_max_length = None
|
73 |
|
|
|
74 |
if model_max_length is not None:
|
75 |
-
num_prompt_tokens = None
|
76 |
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
77 |
# For https://github.com/h2oai/h2ogpt/issues/192
|
78 |
for trial in range(0, 3):
|
@@ -108,10 +108,10 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
108 |
print("Reduced max_new_tokens from %s -> %s" % (
|
109 |
generate_kwargs['max_new_tokens'], max_new_tokens))
|
110 |
generate_kwargs['max_new_tokens'] = max_new_tokens
|
111 |
-
return prompt_text
|
112 |
|
113 |
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
114 |
-
prompt_text = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
115 |
|
116 |
data_point = dict(context='', instruction=prompt_text, input='')
|
117 |
if self.prompter is not None:
|
@@ -132,7 +132,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
132 |
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
133 |
sanitize_bot_response=self.sanitize_bot_response)
|
134 |
elif self.bot and self.human:
|
135 |
-
outputs = rec['generated_text'].split(self.bot)[1].
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
@@ -195,83 +195,6 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
195 |
else:
|
196 |
raise ValueError("TF not avaialble.")
|
197 |
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
198 |
-
import torch
|
199 |
-
from transformers import StoppingCriteria, StoppingCriteriaList
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
204 |
-
|
205 |
-
def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
|
206 |
-
super().__init__()
|
207 |
-
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
208 |
-
self.encounters = encounters
|
209 |
-
self.stops = [stop.to(device) for stop in stops]
|
210 |
-
self.num_stops = [0] * len(stops)
|
211 |
-
self.model_max_length = model_max_length
|
212 |
-
|
213 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
214 |
-
for stopi, stop in enumerate(self.stops):
|
215 |
-
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
216 |
-
self.num_stops[stopi] += 1
|
217 |
-
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
218 |
-
# print("Stopped", flush=True)
|
219 |
-
return True
|
220 |
-
if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
|
221 |
-
# critical limit
|
222 |
-
return True
|
223 |
-
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
224 |
-
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
225 |
-
return False
|
226 |
-
|
227 |
-
|
228 |
-
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
|
229 |
-
# FIXME: prompt_dict unused currently
|
230 |
-
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
231 |
-
if prompt_type == PromptType.human_bot.name:
|
232 |
-
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
233 |
-
# stopping only starts once output is beyond prompt
|
234 |
-
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
235 |
-
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
236 |
-
encounters = [1, 2]
|
237 |
-
elif prompt_type == PromptType.instruct_vicuna.name:
|
238 |
-
# even below is not enough, generic strings and many ways to encode
|
239 |
-
stop_words = [
|
240 |
-
'### Human:',
|
241 |
-
"""
|
242 |
-
### Human:""",
|
243 |
-
"""
|
244 |
-
### Human:
|
245 |
-
""",
|
246 |
-
'### Assistant:',
|
247 |
-
"""
|
248 |
-
### Assistant:""",
|
249 |
-
"""
|
250 |
-
### Assistant:
|
251 |
-
""",
|
252 |
-
]
|
253 |
-
encounters = [1, 2]
|
254 |
-
else:
|
255 |
-
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
256 |
-
stop_words = ['### End']
|
257 |
-
encounters = [1]
|
258 |
-
stop_words_ids = [
|
259 |
-
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
260 |
-
# handle single token case
|
261 |
-
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
262 |
-
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
263 |
-
# avoid padding in front of tokens
|
264 |
-
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
|
265 |
-
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
266 |
-
# handle fake \n added
|
267 |
-
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
268 |
-
# build stopper
|
269 |
-
stopping_criteria = StoppingCriteriaList(
|
270 |
-
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
|
271 |
-
model_max_length=model_max_length)])
|
272 |
-
else:
|
273 |
-
stopping_criteria = StoppingCriteriaList()
|
274 |
-
return stopping_criteria
|
275 |
from enum import Enum
|
276 |
|
277 |
|
@@ -296,6 +219,12 @@ class PromptType(Enum):
|
|
296 |
wizard2 = 16
|
297 |
wizard3 = 17
|
298 |
instruct_simple = 18
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
|
300 |
|
301 |
class DocumentChoices(Enum):
|
@@ -318,9 +247,41 @@ class LangChainMode(Enum):
|
|
318 |
MY_DATA = "MyData"
|
319 |
GITHUB_H2OGPT = "github h2oGPT"
|
320 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
import ast
|
322 |
import time
|
323 |
-
from enums import PromptType # also supports imports from this file from other files
|
324 |
|
325 |
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
326 |
|
@@ -344,23 +305,29 @@ prompt_type_to_model_name = {
|
|
344 |
'mosaicml/mpt-7b-storywriter',
|
345 |
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
346 |
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
347 |
-
'
|
348 |
-
'llama', # plain, or need to choose prompt_type for given TheBloke model
|
349 |
-
'gpt4all_llama', # internally handles prompting
|
350 |
],
|
|
|
351 |
'prompt_answer': [
|
352 |
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
353 |
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
354 |
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
355 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
356 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
357 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
358 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
359 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
|
360 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
|
|
|
361 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
|
362 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
|
363 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
],
|
365 |
'instruct': [],
|
366 |
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
@@ -373,6 +340,7 @@ prompt_type_to_model_name = {
|
|
373 |
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
374 |
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
375 |
'h2oai/h2ogpt-research-oasst1-512-30b',
|
|
|
376 |
'h2oai/h2ogpt-oasst1-falcon-40b',
|
377 |
'h2oai/h2ogpt-oig-oasst1-falcon-40b',
|
378 |
],
|
@@ -385,7 +353,16 @@ prompt_type_to_model_name = {
|
|
385 |
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
386 |
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
387 |
"instruct_simple": ['JosephusCheung/Guanaco'],
|
|
|
|
|
|
|
|
|
388 |
}
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
391 |
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
@@ -399,18 +376,29 @@ for p in PromptType:
|
|
399 |
prompt_types.extend([p.name, p.value, str(p.value)])
|
400 |
|
401 |
|
402 |
-
def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
|
403 |
prompt_dict_error = ''
|
|
|
|
|
404 |
if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
|
405 |
try:
|
406 |
prompt_dict = ast.literal_eval(prompt_dict)
|
407 |
except BaseException as e:
|
408 |
prompt_dict_error = str(e)
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
promptA = prompt_dict.get('promptA', '')
|
415 |
promptB = prompt_dict('promptB', '')
|
416 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
@@ -418,21 +406,23 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
|
|
418 |
PreResponse = prompt_dict.get('PreResponse', '')
|
419 |
terminate_response = prompt_dict.get('terminate_response', None)
|
420 |
chat_sep = prompt_dict.get('chat_sep', '\n')
|
|
|
421 |
humanstr = prompt_dict.get('humanstr', '')
|
422 |
botstr = prompt_dict.get('botstr', '')
|
423 |
elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
424 |
PromptType.plain.name]:
|
425 |
-
promptA = promptB = PreInstruct = PreInput = PreResponse =
|
426 |
terminate_response = []
|
427 |
-
chat_sep = ''
|
428 |
-
|
429 |
-
|
|
|
430 |
elif prompt_type == 'simple_instruct':
|
431 |
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
432 |
terminate_response = []
|
433 |
-
chat_sep = '\n'
|
434 |
-
humanstr =
|
435 |
-
botstr =
|
436 |
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
437 |
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
438 |
str(PromptType.instruct_with_end.value),
|
@@ -458,7 +448,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
|
|
458 |
terminate_response = ['### End']
|
459 |
else:
|
460 |
terminate_response = None
|
461 |
-
chat_sep = '\n'
|
462 |
humanstr = PreInstruct
|
463 |
botstr = PreResponse
|
464 |
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
@@ -480,7 +470,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
|
|
480 |
### Response:
|
481 |
"""
|
482 |
terminate_response = None
|
483 |
-
chat_sep = '\n'
|
484 |
humanstr = PreInstruct # first thing human says
|
485 |
botstr = PreResponse # first thing bot says
|
486 |
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
@@ -502,14 +492,14 @@ Current Time: {}
|
|
502 |
|
503 |
"""
|
504 |
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
505 |
-
start =
|
506 |
-
promptB = promptA = '%s%s
|
507 |
|
508 |
-
PreInstruct =
|
509 |
|
510 |
PreInput = None
|
511 |
|
512 |
-
if
|
513 |
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
514 |
PreResponse = bot + ' '
|
515 |
else:
|
@@ -517,10 +507,11 @@ Current Time: {}
|
|
517 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
518 |
PreResponse = bot
|
519 |
|
520 |
-
terminate_response = [
|
521 |
-
chat_sep = '\n'
|
522 |
humanstr = human # tag before human talks
|
523 |
botstr = bot # tag before bot talks
|
|
|
524 |
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
525 |
PromptType.dai_faq.name]:
|
526 |
promptA = ''
|
@@ -536,7 +527,7 @@ Current Time: {}
|
|
536 |
### Driverless AI documentation answer:
|
537 |
"""
|
538 |
terminate_response = ['\n\n']
|
539 |
-
chat_sep = terminate_response
|
540 |
humanstr = PreInstruct
|
541 |
botstr = PreResponse
|
542 |
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
@@ -545,7 +536,7 @@ Current Time: {}
|
|
545 |
PreInstruct = '## Main Text\n\n'
|
546 |
PreResponse = '\n\n## Summary\n\n'
|
547 |
terminate_response = None
|
548 |
-
chat_sep = '\n'
|
549 |
humanstr = PreInstruct
|
550 |
botstr = PreResponse
|
551 |
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
@@ -565,7 +556,7 @@ Current Time: {}
|
|
565 |
"""
|
566 |
terminate_response = [
|
567 |
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
568 |
-
chat_sep = '\n'
|
569 |
humanstr = PreInstruct
|
570 |
botstr = PreResponse
|
571 |
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
@@ -573,33 +564,50 @@ Current Time: {}
|
|
573 |
preprompt = ''
|
574 |
prompt_tokens = "<|prompt|>"
|
575 |
answer_tokens = "<|answer|>"
|
576 |
-
start =
|
577 |
promptB = promptA = '%s%s' % (preprompt, start)
|
578 |
-
PreInstruct =
|
579 |
PreInput = None
|
580 |
PreResponse = answer_tokens
|
581 |
eos = '<|endoftext|>' # neox eos
|
582 |
-
terminate_response = [start, PreResponse, eos]
|
583 |
-
chat_sep = eos
|
584 |
humanstr = prompt_tokens
|
585 |
botstr = answer_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
586 |
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
587 |
PromptType.open_assistant.name]:
|
588 |
# From added_tokens.json
|
589 |
preprompt = ''
|
590 |
prompt_tokens = "<|prompter|>"
|
591 |
answer_tokens = "<|assistant|>"
|
592 |
-
start =
|
593 |
promptB = promptA = '%s%s' % (preprompt, start)
|
594 |
-
PreInstruct =
|
595 |
PreInput = None
|
596 |
PreResponse = answer_tokens
|
597 |
pend = "<|prefix_end|>"
|
598 |
eos = "</s>"
|
599 |
-
terminate_response = [start, PreResponse, pend, eos]
|
600 |
-
chat_sep = eos
|
601 |
humanstr = prompt_tokens
|
602 |
botstr = answer_tokens
|
|
|
|
|
603 |
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
604 |
PromptType.wizard_lm.name]:
|
605 |
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
@@ -611,7 +619,7 @@ Current Time: {}
|
|
611 |
PreResponse = "\n\n### Response\n"
|
612 |
eos = "</s>"
|
613 |
terminate_response = [PreResponse, eos]
|
614 |
-
chat_sep = eos
|
615 |
humanstr = promptA
|
616 |
botstr = PreResponse
|
617 |
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
@@ -627,13 +635,12 @@ Current Time: {}
|
|
627 |
### Assistant:
|
628 |
"""
|
629 |
terminate_response = [PreResponse]
|
630 |
-
chat_sep = '\n'
|
631 |
humanstr = PreInstruct
|
632 |
botstr = PreResponse
|
633 |
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
634 |
PromptType.instruct_vicuna2.name]:
|
635 |
-
promptA = promptB = "" if not (
|
636 |
-
chat and reduced) else ''
|
637 |
|
638 |
PreInstruct = """
|
639 |
HUMAN:
|
@@ -646,13 +653,12 @@ ASSISTANT:
|
|
646 |
"""
|
647 |
terminate_response = [
|
648 |
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
649 |
-
chat_sep = '\n'
|
650 |
humanstr = PreInstruct
|
651 |
botstr = PreResponse
|
652 |
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
653 |
PromptType.instruct_vicuna3.name]:
|
654 |
-
promptA = promptB = "" if not (
|
655 |
-
chat and reduced) else ''
|
656 |
|
657 |
PreInstruct = """
|
658 |
### User:
|
@@ -665,13 +671,14 @@ ASSISTANT:
|
|
665 |
"""
|
666 |
terminate_response = [
|
667 |
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
668 |
-
chat_sep = '\n'
|
669 |
humanstr = PreInstruct
|
670 |
botstr = PreResponse
|
671 |
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
672 |
PromptType.wizard2.name]:
|
673 |
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
674 |
-
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
|
|
|
675 |
start = ''
|
676 |
promptB = promptA = '%s%s' % (preprompt, start)
|
677 |
PreInstruct = """
|
@@ -682,27 +689,39 @@ ASSISTANT:
|
|
682 |
### Response:
|
683 |
"""
|
684 |
terminate_response = [PreResponse]
|
685 |
-
chat_sep = '\n'
|
686 |
humanstr = PreInstruct
|
687 |
botstr = PreResponse
|
688 |
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
689 |
PromptType.wizard3.name]:
|
690 |
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
691 |
-
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
|
|
692 |
start = ''
|
693 |
promptB = promptA = '%s%s' % (preprompt, start)
|
694 |
PreInstruct = """USER: """
|
695 |
PreInput = None
|
696 |
PreResponse = """ASSISTANT: """
|
697 |
terminate_response = [PreResponse]
|
698 |
-
chat_sep = '\n'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
699 |
humanstr = PreInstruct
|
700 |
botstr = PreResponse
|
701 |
|
702 |
elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
|
703 |
PromptType.instruct_simple.name]:
|
704 |
-
promptA = '' if not (chat and reduced) else ''
|
705 |
-
promptB = '' if not (chat and reduced) else ''
|
706 |
|
707 |
PreInstruct = """
|
708 |
### Instruction:
|
@@ -716,21 +735,90 @@ ASSISTANT:
|
|
716 |
### Response:
|
717 |
"""
|
718 |
terminate_response = None
|
719 |
-
chat_sep = '\n'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
720 |
humanstr = PreInstruct
|
721 |
botstr = PreResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
722 |
else:
|
723 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
724 |
|
725 |
-
if
|
726 |
-
|
|
|
|
|
727 |
PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
|
728 |
-
|
|
|
|
|
|
|
|
|
|
|
729 |
else:
|
730 |
-
return
|
731 |
|
732 |
|
733 |
-
def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
734 |
context = data_point.get('context')
|
735 |
if context is None:
|
736 |
context = ''
|
@@ -741,9 +829,12 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
|
741 |
prompt_dict = data_point.get('prompt_dict', prompt_dict)
|
742 |
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
743 |
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
744 |
-
terminate_response, chat_sep, humanstr, botstr
|
|
|
|
|
745 |
|
746 |
-
|
|
|
747 |
|
748 |
if input and promptA:
|
749 |
prompt += f"""{promptA}"""
|
@@ -793,7 +884,7 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
|
793 |
if output:
|
794 |
prompt += f"""{output}"""
|
795 |
|
796 |
-
return prompt, pre_response, terminate_response, chat_sep
|
797 |
|
798 |
|
799 |
def inject_chatsep(prompt_type, prompt, chat_sep=None):
|
@@ -808,9 +899,6 @@ class Prompter(object):
|
|
808 |
allowed_repeat_line_length=10):
|
809 |
self.prompt_type = prompt_type
|
810 |
self.prompt_dict = prompt_dict
|
811 |
-
data_point = dict(instruction='', input='', output='')
|
812 |
-
_, self.pre_response, self.terminate_response, self.chat_sep = \
|
813 |
-
generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
|
814 |
self.debug = debug
|
815 |
self.chat = chat
|
816 |
self.stream_output = stream_output
|
@@ -819,15 +907,33 @@ class Prompter(object):
|
|
819 |
self.prompt = None
|
820 |
context = "" # not for chat context
|
821 |
reduced = False # not for chat context
|
|
|
822 |
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
823 |
-
self.terminate_response, self.chat_sep, self.humanstr, self.botstr
|
824 |
-
|
|
|
|
|
825 |
|
826 |
-
def generate_prompt(self, data_point):
|
827 |
-
|
828 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
if self.debug:
|
830 |
print("prompt: %s" % prompt, flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
831 |
self.prompt = prompt
|
832 |
return prompt
|
833 |
|
@@ -846,7 +952,8 @@ class Prompter(object):
|
|
846 |
if sanitize_bot_response:
|
847 |
from better_profanity import profanity
|
848 |
response = profanity.censor(response)
|
849 |
-
response
|
|
|
850 |
return response
|
851 |
|
852 |
def clean_repeats(response):
|
@@ -868,12 +975,12 @@ class Prompter(object):
|
|
868 |
# then use most basic parsing like pipeline
|
869 |
if self.botstr in output:
|
870 |
if self.humanstr:
|
871 |
-
output = clean_response(output.split(self.botstr)[1].
|
872 |
else:
|
873 |
# i.e. use after bot but only up to next bot
|
874 |
-
output = clean_response(output.split(self.botstr)[1].
|
875 |
else:
|
876 |
-
# output = clean_response(output
|
877 |
# assume just not printed yet
|
878 |
output = ""
|
879 |
else:
|
@@ -900,9 +1007,9 @@ class Prompter(object):
|
|
900 |
allow_terminate = True
|
901 |
output = output[len(prompt):]
|
902 |
# clean after subtract prompt out, so correct removal of pre_response
|
903 |
-
output = clean_response(output)
|
904 |
if self.repeat_penalty:
|
905 |
-
output = clean_repeats(output)
|
906 |
if self.terminate_response and allow_terminate:
|
907 |
finds = []
|
908 |
for term in self.terminate_response:
|
@@ -910,11 +1017,9 @@ class Prompter(object):
|
|
910 |
finds = [x for x in finds if x >= 0]
|
911 |
if len(finds) > 0:
|
912 |
termi = finds[0]
|
913 |
-
output = output[:termi]
|
914 |
else:
|
915 |
-
output = output
|
916 |
-
else:
|
917 |
-
output = output.strip()
|
918 |
if multi_output:
|
919 |
# prefix with output counter
|
920 |
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
@@ -927,3 +1032,80 @@ class Prompter(object):
|
|
927 |
if self.debug:
|
928 |
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
929 |
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# unknown
|
72 |
model_max_length = None
|
73 |
|
74 |
+
num_prompt_tokens = None
|
75 |
if model_max_length is not None:
|
|
|
76 |
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
77 |
# For https://github.com/h2oai/h2ogpt/issues/192
|
78 |
for trial in range(0, 3):
|
|
|
108 |
print("Reduced max_new_tokens from %s -> %s" % (
|
109 |
generate_kwargs['max_new_tokens'], max_new_tokens))
|
110 |
generate_kwargs['max_new_tokens'] = max_new_tokens
|
111 |
+
return prompt_text, num_prompt_tokens
|
112 |
|
113 |
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
114 |
+
prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
115 |
|
116 |
data_point = dict(context='', instruction=prompt_text, input='')
|
117 |
if self.prompter is not None:
|
|
|
132 |
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
133 |
sanitize_bot_response=self.sanitize_bot_response)
|
134 |
elif self.bot and self.human:
|
135 |
+
outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
|
|
195 |
else:
|
196 |
raise ValueError("TF not avaialble.")
|
197 |
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
from enum import Enum
|
199 |
|
200 |
|
|
|
219 |
wizard2 = 16
|
220 |
wizard3 = 17
|
221 |
instruct_simple = 18
|
222 |
+
wizard_vicuna = 19
|
223 |
+
openai = 20
|
224 |
+
openai_chat = 21
|
225 |
+
gptj = 22
|
226 |
+
prompt_answer_openllama = 23
|
227 |
+
vicuna11 = 24
|
228 |
|
229 |
|
230 |
class DocumentChoices(Enum):
|
|
|
247 |
MY_DATA = "MyData"
|
248 |
GITHUB_H2OGPT = "github h2oGPT"
|
249 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
250 |
+
|
251 |
+
|
252 |
+
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
253 |
+
|
254 |
+
|
255 |
+
# from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
|
256 |
+
model_token_mapping = {
|
257 |
+
"gpt-4": 8192,
|
258 |
+
"gpt-4-0314": 8192,
|
259 |
+
"gpt-4-32k": 32768,
|
260 |
+
"gpt-4-32k-0314": 32768,
|
261 |
+
"gpt-3.5-turbo": 4096,
|
262 |
+
"gpt-3.5-turbo-16k": 16*1024,
|
263 |
+
"gpt-3.5-turbo-0301": 4096,
|
264 |
+
"text-ada-001": 2049,
|
265 |
+
"ada": 2049,
|
266 |
+
"text-babbage-001": 2040,
|
267 |
+
"babbage": 2049,
|
268 |
+
"text-curie-001": 2049,
|
269 |
+
"curie": 2049,
|
270 |
+
"davinci": 2049,
|
271 |
+
"text-davinci-003": 4097,
|
272 |
+
"text-davinci-002": 4097,
|
273 |
+
"code-davinci-002": 8001,
|
274 |
+
"code-davinci-001": 8001,
|
275 |
+
"code-cushman-002": 2048,
|
276 |
+
"code-cushman-001": 2048,
|
277 |
+
}
|
278 |
+
|
279 |
+
|
280 |
+
source_prefix = "Sources [Score | Link]:"
|
281 |
+
source_postfix = "End Sources<p>"
|
282 |
+
import os
|
283 |
import ast
|
284 |
import time
|
|
|
285 |
|
286 |
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
287 |
|
|
|
305 |
'mosaicml/mpt-7b-storywriter',
|
306 |
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
307 |
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
308 |
+
'mosaicml/mpt-30b-instruct', # internal code handles instruct
|
|
|
|
|
309 |
],
|
310 |
+
'gptj': ['gptj', 'gpt4all_llama'],
|
311 |
'prompt_answer': [
|
312 |
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
313 |
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
314 |
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
|
|
|
|
|
|
|
|
315 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
|
316 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
|
317 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
|
318 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
|
319 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
|
320 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
|
321 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
|
322 |
+
'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
|
323 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
|
324 |
+
],
|
325 |
+
'prompt_answer_openllama': [
|
326 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
327 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
328 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
329 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
330 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
|
331 |
],
|
332 |
'instruct': [],
|
333 |
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
|
|
340 |
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
341 |
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
342 |
'h2oai/h2ogpt-research-oasst1-512-30b',
|
343 |
+
'h2oai/h2ogpt-research-oasst1-llama-65b',
|
344 |
'h2oai/h2ogpt-oasst1-falcon-40b',
|
345 |
'h2oai/h2ogpt-oig-oasst1-falcon-40b',
|
346 |
],
|
|
|
353 |
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
354 |
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
355 |
"instruct_simple": ['JosephusCheung/Guanaco'],
|
356 |
+
"wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
|
357 |
+
"wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
|
358 |
+
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
359 |
+
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
360 |
}
|
361 |
+
if os.getenv('OPENAI_API_KEY'):
|
362 |
+
prompt_type_to_model_name.update({
|
363 |
+
"openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
|
364 |
+
"openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
|
365 |
+
})
|
366 |
|
367 |
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
368 |
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
|
|
376 |
prompt_types.extend([p.name, p.value, str(p.value)])
|
377 |
|
378 |
|
379 |
+
def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
|
380 |
prompt_dict_error = ''
|
381 |
+
generates_leading_space = False
|
382 |
+
|
383 |
if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
|
384 |
try:
|
385 |
prompt_dict = ast.literal_eval(prompt_dict)
|
386 |
except BaseException as e:
|
387 |
prompt_dict_error = str(e)
|
388 |
+
if prompt_dict_error:
|
389 |
+
promptA = None
|
390 |
+
promptB = None
|
391 |
+
PreInstruct = None
|
392 |
+
PreInput = ''
|
393 |
+
PreResponse = ''
|
394 |
+
terminate_response = None
|
395 |
+
chat_sep = ''
|
396 |
+
chat_turn_sep = ''
|
397 |
+
humanstr = ''
|
398 |
+
botstr = ''
|
399 |
+
generates_leading_space = False
|
400 |
+
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
401 |
+
PromptType.custom.name]:
|
402 |
promptA = prompt_dict.get('promptA', '')
|
403 |
promptB = prompt_dict('promptB', '')
|
404 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
|
|
406 |
PreResponse = prompt_dict.get('PreResponse', '')
|
407 |
terminate_response = prompt_dict.get('terminate_response', None)
|
408 |
chat_sep = prompt_dict.get('chat_sep', '\n')
|
409 |
+
chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
|
410 |
humanstr = prompt_dict.get('humanstr', '')
|
411 |
botstr = prompt_dict.get('botstr', '')
|
412 |
elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
413 |
PromptType.plain.name]:
|
414 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
415 |
terminate_response = []
|
416 |
+
chat_turn_sep = chat_sep = ''
|
417 |
+
# plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
|
418 |
+
humanstr = None
|
419 |
+
botstr = None
|
420 |
elif prompt_type == 'simple_instruct':
|
421 |
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
422 |
terminate_response = []
|
423 |
+
chat_turn_sep = chat_sep = '\n'
|
424 |
+
humanstr = None
|
425 |
+
botstr = None
|
426 |
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
427 |
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
428 |
str(PromptType.instruct_with_end.value),
|
|
|
448 |
terminate_response = ['### End']
|
449 |
else:
|
450 |
terminate_response = None
|
451 |
+
chat_turn_sep = chat_sep = '\n'
|
452 |
humanstr = PreInstruct
|
453 |
botstr = PreResponse
|
454 |
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
|
|
470 |
### Response:
|
471 |
"""
|
472 |
terminate_response = None
|
473 |
+
chat_turn_sep = chat_sep = '\n'
|
474 |
humanstr = PreInstruct # first thing human says
|
475 |
botstr = PreResponse # first thing bot says
|
476 |
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
|
|
492 |
|
493 |
"""
|
494 |
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
495 |
+
start = ''
|
496 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
497 |
|
498 |
+
PreInstruct = human + ' '
|
499 |
|
500 |
PreInput = None
|
501 |
|
502 |
+
if making_context:
|
503 |
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
504 |
PreResponse = bot + ' '
|
505 |
else:
|
|
|
507 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
508 |
PreResponse = bot
|
509 |
|
510 |
+
terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
|
511 |
+
chat_turn_sep = chat_sep = '\n'
|
512 |
humanstr = human # tag before human talks
|
513 |
botstr = bot # tag before bot talks
|
514 |
+
generates_leading_space = True
|
515 |
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
516 |
PromptType.dai_faq.name]:
|
517 |
promptA = ''
|
|
|
527 |
### Driverless AI documentation answer:
|
528 |
"""
|
529 |
terminate_response = ['\n\n']
|
530 |
+
chat_turn_sep = chat_sep = terminate_response
|
531 |
humanstr = PreInstruct
|
532 |
botstr = PreResponse
|
533 |
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
|
|
536 |
PreInstruct = '## Main Text\n\n'
|
537 |
PreResponse = '\n\n## Summary\n\n'
|
538 |
terminate_response = None
|
539 |
+
chat_turn_sep = chat_sep = '\n'
|
540 |
humanstr = PreInstruct
|
541 |
botstr = PreResponse
|
542 |
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
|
|
556 |
"""
|
557 |
terminate_response = [
|
558 |
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
559 |
+
chat_turn_sep = chat_sep = '\n'
|
560 |
humanstr = PreInstruct
|
561 |
botstr = PreResponse
|
562 |
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
|
|
564 |
preprompt = ''
|
565 |
prompt_tokens = "<|prompt|>"
|
566 |
answer_tokens = "<|answer|>"
|
567 |
+
start = ''
|
568 |
promptB = promptA = '%s%s' % (preprompt, start)
|
569 |
+
PreInstruct = prompt_tokens
|
570 |
PreInput = None
|
571 |
PreResponse = answer_tokens
|
572 |
eos = '<|endoftext|>' # neox eos
|
|
|
|
|
573 |
humanstr = prompt_tokens
|
574 |
botstr = answer_tokens
|
575 |
+
terminate_response = [humanstr, PreResponse, eos]
|
576 |
+
chat_sep = ''
|
577 |
+
chat_turn_sep = eos
|
578 |
+
elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
|
579 |
+
PromptType.prompt_answer_openllama.name]:
|
580 |
+
preprompt = ''
|
581 |
+
prompt_tokens = "<|prompt|>"
|
582 |
+
answer_tokens = "<|answer|>"
|
583 |
+
start = ''
|
584 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
585 |
+
PreInstruct = prompt_tokens
|
586 |
+
PreInput = None
|
587 |
+
PreResponse = answer_tokens
|
588 |
+
eos = '</s>' # llama eos
|
589 |
+
humanstr = prompt_tokens
|
590 |
+
botstr = answer_tokens
|
591 |
+
terminate_response = [humanstr, PreResponse, eos]
|
592 |
+
chat_sep = ''
|
593 |
+
chat_turn_sep = eos
|
594 |
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
595 |
PromptType.open_assistant.name]:
|
596 |
# From added_tokens.json
|
597 |
preprompt = ''
|
598 |
prompt_tokens = "<|prompter|>"
|
599 |
answer_tokens = "<|assistant|>"
|
600 |
+
start = ''
|
601 |
promptB = promptA = '%s%s' % (preprompt, start)
|
602 |
+
PreInstruct = prompt_tokens
|
603 |
PreInput = None
|
604 |
PreResponse = answer_tokens
|
605 |
pend = "<|prefix_end|>"
|
606 |
eos = "</s>"
|
|
|
|
|
607 |
humanstr = prompt_tokens
|
608 |
botstr = answer_tokens
|
609 |
+
terminate_response = [humanstr, PreResponse, pend, eos]
|
610 |
+
chat_turn_sep = chat_sep = eos
|
611 |
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
612 |
PromptType.wizard_lm.name]:
|
613 |
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
|
|
619 |
PreResponse = "\n\n### Response\n"
|
620 |
eos = "</s>"
|
621 |
terminate_response = [PreResponse, eos]
|
622 |
+
chat_turn_sep = chat_sep = eos
|
623 |
humanstr = promptA
|
624 |
botstr = PreResponse
|
625 |
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
|
|
635 |
### Assistant:
|
636 |
"""
|
637 |
terminate_response = [PreResponse]
|
638 |
+
chat_turn_sep = chat_sep = '\n'
|
639 |
humanstr = PreInstruct
|
640 |
botstr = PreResponse
|
641 |
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
642 |
PromptType.instruct_vicuna2.name]:
|
643 |
+
promptA = promptB = "" if not (chat and reduced) else ''
|
|
|
644 |
|
645 |
PreInstruct = """
|
646 |
HUMAN:
|
|
|
653 |
"""
|
654 |
terminate_response = [
|
655 |
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
656 |
+
chat_turn_sep = chat_sep = '\n'
|
657 |
humanstr = PreInstruct
|
658 |
botstr = PreResponse
|
659 |
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
660 |
PromptType.instruct_vicuna3.name]:
|
661 |
+
promptA = promptB = "" if not (chat and reduced) else ''
|
|
|
662 |
|
663 |
PreInstruct = """
|
664 |
### User:
|
|
|
671 |
"""
|
672 |
terminate_response = [
|
673 |
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
674 |
+
chat_turn_sep = chat_sep = '\n'
|
675 |
humanstr = PreInstruct
|
676 |
botstr = PreResponse
|
677 |
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
678 |
PromptType.wizard2.name]:
|
679 |
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
680 |
+
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
|
681 |
+
chat and reduced) else ''
|
682 |
start = ''
|
683 |
promptB = promptA = '%s%s' % (preprompt, start)
|
684 |
PreInstruct = """
|
|
|
689 |
### Response:
|
690 |
"""
|
691 |
terminate_response = [PreResponse]
|
692 |
+
chat_turn_sep = chat_sep = '\n'
|
693 |
humanstr = PreInstruct
|
694 |
botstr = PreResponse
|
695 |
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
696 |
PromptType.wizard3.name]:
|
697 |
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
698 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
|
699 |
+
chat and reduced) else ''
|
700 |
start = ''
|
701 |
promptB = promptA = '%s%s' % (preprompt, start)
|
702 |
PreInstruct = """USER: """
|
703 |
PreInput = None
|
704 |
PreResponse = """ASSISTANT: """
|
705 |
terminate_response = [PreResponse]
|
706 |
+
chat_turn_sep = chat_sep = '\n'
|
707 |
+
humanstr = PreInstruct
|
708 |
+
botstr = PreResponse
|
709 |
+
elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
|
710 |
+
PromptType.wizard_vicuna.name]:
|
711 |
+
preprompt = ''
|
712 |
+
start = ''
|
713 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
714 |
+
PreInstruct = """USER: """
|
715 |
+
PreInput = None
|
716 |
+
PreResponse = """ASSISTANT: """
|
717 |
+
terminate_response = [PreResponse]
|
718 |
+
chat_turn_sep = chat_sep = '\n'
|
719 |
humanstr = PreInstruct
|
720 |
botstr = PreResponse
|
721 |
|
722 |
elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
|
723 |
PromptType.instruct_simple.name]:
|
724 |
+
promptB = promptA = '' if not (chat and reduced) else ''
|
|
|
725 |
|
726 |
PreInstruct = """
|
727 |
### Instruction:
|
|
|
735 |
### Response:
|
736 |
"""
|
737 |
terminate_response = None
|
738 |
+
chat_turn_sep = chat_sep = '\n'
|
739 |
+
humanstr = PreInstruct
|
740 |
+
botstr = PreResponse
|
741 |
+
elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
|
742 |
+
PromptType.openai.name]:
|
743 |
+
preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
|
744 |
+
chat and reduced) else ''
|
745 |
+
start = ''
|
746 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
747 |
+
PreInstruct = "\nHuman: "
|
748 |
+
PreInput = None
|
749 |
+
PreResponse = "\nAI:"
|
750 |
+
terminate_response = [PreResponse] + [" Human:", " AI:"]
|
751 |
+
chat_turn_sep = chat_sep = '\n'
|
752 |
+
humanstr = PreInstruct
|
753 |
+
botstr = PreResponse
|
754 |
+
elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
|
755 |
+
PromptType.gptj.name]:
|
756 |
+
preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
|
757 |
+
chat and reduced) else ''
|
758 |
+
start = ''
|
759 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
760 |
+
PreInstruct = "\n### Prompt: "
|
761 |
+
PreInput = None
|
762 |
+
PreResponse = "\n### Response: "
|
763 |
+
terminate_response = [PreResponse] + ["Prompt:", "Response:"]
|
764 |
+
chat_turn_sep = chat_sep = '\n'
|
765 |
humanstr = PreInstruct
|
766 |
botstr = PreResponse
|
767 |
+
elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
|
768 |
+
PromptType.openai_chat.name]:
|
769 |
+
# prompting and termination all handled by endpoint
|
770 |
+
preprompt = """"""
|
771 |
+
start = ''
|
772 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
773 |
+
PreInstruct = ""
|
774 |
+
PreInput = None
|
775 |
+
PreResponse = ""
|
776 |
+
terminate_response = []
|
777 |
+
chat_turn_sep = chat_sep = '\n'
|
778 |
+
humanstr = None
|
779 |
+
botstr = None
|
780 |
+
elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
|
781 |
+
PromptType.vicuna11.name]:
|
782 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
|
783 |
+
chat and reduced) else ''
|
784 |
+
start = ''
|
785 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
786 |
+
eos = '</s>'
|
787 |
+
PreInstruct = """USER: """
|
788 |
+
PreInput = None
|
789 |
+
PreResponse = """ASSISTANT:"""
|
790 |
+
terminate_response = [PreResponse]
|
791 |
+
chat_sep = ' '
|
792 |
+
chat_turn_sep = eos
|
793 |
+
humanstr = PreInstruct
|
794 |
+
botstr = PreResponse
|
795 |
+
|
796 |
+
if making_context:
|
797 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
798 |
+
PreResponse = PreResponse + ' '
|
799 |
+
else:
|
800 |
+
# normally LLM adds space after this, because was how trained.
|
801 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
802 |
+
PreResponse = PreResponse
|
803 |
else:
|
804 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
805 |
|
806 |
+
if isinstance(terminate_response, (tuple, list)):
|
807 |
+
assert '' not in terminate_response, "Bad terminate_response"
|
808 |
+
|
809 |
+
ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
|
810 |
PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
|
811 |
+
chat_turn_sep=chat_turn_sep,
|
812 |
+
humanstr=humanstr, botstr=botstr,
|
813 |
+
generates_leading_space=generates_leading_space)
|
814 |
+
|
815 |
+
if return_dict:
|
816 |
+
return ret_dict, prompt_dict_error
|
817 |
else:
|
818 |
+
return tuple(list(ret_dict.values()))
|
819 |
|
820 |
|
821 |
+
def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
|
822 |
context = data_point.get('context')
|
823 |
if context is None:
|
824 |
context = ''
|
|
|
829 |
prompt_dict = data_point.get('prompt_dict', prompt_dict)
|
830 |
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
831 |
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
832 |
+
terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
|
833 |
+
generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
|
834 |
+
context, reduced, making_context)
|
835 |
|
836 |
+
# could avoid if reduce=True, but too complex for parent functions to handle
|
837 |
+
prompt = context
|
838 |
|
839 |
if input and promptA:
|
840 |
prompt += f"""{promptA}"""
|
|
|
884 |
if output:
|
885 |
prompt += f"""{output}"""
|
886 |
|
887 |
+
return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
|
888 |
|
889 |
|
890 |
def inject_chatsep(prompt_type, prompt, chat_sep=None):
|
|
|
899 |
allowed_repeat_line_length=10):
|
900 |
self.prompt_type = prompt_type
|
901 |
self.prompt_dict = prompt_dict
|
|
|
|
|
|
|
902 |
self.debug = debug
|
903 |
self.chat = chat
|
904 |
self.stream_output = stream_output
|
|
|
907 |
self.prompt = None
|
908 |
context = "" # not for chat context
|
909 |
reduced = False # not for chat context
|
910 |
+
making_context = False # not for chat context
|
911 |
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
912 |
+
self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
|
913 |
+
self.generates_leading_space = \
|
914 |
+
get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
|
915 |
+
self.pre_response = self.PreResponse
|
916 |
|
917 |
+
def generate_prompt(self, data_point, reduced=None):
|
918 |
+
"""
|
919 |
+
data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
|
920 |
+
:param data_point:
|
921 |
+
:param reduced:
|
922 |
+
:return:
|
923 |
+
"""
|
924 |
+
reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
|
925 |
+
making_context = False # whether really making final prompt or just generating context
|
926 |
+
prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
|
927 |
+
making_context)
|
928 |
if self.debug:
|
929 |
print("prompt: %s" % prompt, flush=True)
|
930 |
+
# if have context, should have always reduced and only preappend promptA/B here
|
931 |
+
if data_point.get('context'):
|
932 |
+
if data_point.get('input') and self.promptA:
|
933 |
+
prompt = self.promptA + prompt
|
934 |
+
elif self.promptB:
|
935 |
+
prompt = self.promptB + prompt
|
936 |
+
|
937 |
self.prompt = prompt
|
938 |
return prompt
|
939 |
|
|
|
952 |
if sanitize_bot_response:
|
953 |
from better_profanity import profanity
|
954 |
response = profanity.censor(response)
|
955 |
+
if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
|
956 |
+
response = response[1:]
|
957 |
return response
|
958 |
|
959 |
def clean_repeats(response):
|
|
|
975 |
# then use most basic parsing like pipeline
|
976 |
if self.botstr in output:
|
977 |
if self.humanstr:
|
978 |
+
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
979 |
else:
|
980 |
# i.e. use after bot but only up to next bot
|
981 |
+
output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
|
982 |
else:
|
983 |
+
# output = clean_response(output)
|
984 |
# assume just not printed yet
|
985 |
output = ""
|
986 |
else:
|
|
|
1007 |
allow_terminate = True
|
1008 |
output = output[len(prompt):]
|
1009 |
# clean after subtract prompt out, so correct removal of pre_response
|
1010 |
+
output = clean_response(output)
|
1011 |
if self.repeat_penalty:
|
1012 |
+
output = clean_repeats(output)
|
1013 |
if self.terminate_response and allow_terminate:
|
1014 |
finds = []
|
1015 |
for term in self.terminate_response:
|
|
|
1017 |
finds = [x for x in finds if x >= 0]
|
1018 |
if len(finds) > 0:
|
1019 |
termi = finds[0]
|
1020 |
+
output = output[:termi]
|
1021 |
else:
|
1022 |
+
output = output
|
|
|
|
|
1023 |
if multi_output:
|
1024 |
# prefix with output counter
|
1025 |
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
|
|
1032 |
if self.debug:
|
1033 |
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
1034 |
return output
|
1035 |
+
import torch
|
1036 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
1037 |
+
|
1038 |
+
|
1039 |
+
|
1040 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
1041 |
+
|
1042 |
+
def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
|
1043 |
+
super().__init__()
|
1044 |
+
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
1045 |
+
self.encounters = encounters
|
1046 |
+
self.stops = [stop.to(device) for stop in stops]
|
1047 |
+
self.num_stops = [0] * len(stops)
|
1048 |
+
self.model_max_length = model_max_length
|
1049 |
+
|
1050 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
1051 |
+
for stopi, stop in enumerate(self.stops):
|
1052 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
1053 |
+
self.num_stops[stopi] += 1
|
1054 |
+
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
1055 |
+
# print("Stopped", flush=True)
|
1056 |
+
return True
|
1057 |
+
if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
|
1058 |
+
# critical limit
|
1059 |
+
return True
|
1060 |
+
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
1061 |
+
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
1062 |
+
return False
|
1063 |
+
|
1064 |
+
|
1065 |
+
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
|
1066 |
+
# FIXME: prompt_dict unused currently
|
1067 |
+
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
1068 |
+
if prompt_type == PromptType.human_bot.name:
|
1069 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
1070 |
+
# stopping only starts once output is beyond prompt
|
1071 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
1072 |
+
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
1073 |
+
encounters = [1, 2]
|
1074 |
+
elif prompt_type == PromptType.instruct_vicuna.name:
|
1075 |
+
# even below is not enough, generic strings and many ways to encode
|
1076 |
+
stop_words = [
|
1077 |
+
'### Human:',
|
1078 |
+
"""
|
1079 |
+
### Human:""",
|
1080 |
+
"""
|
1081 |
+
### Human:
|
1082 |
+
""",
|
1083 |
+
'### Assistant:',
|
1084 |
+
"""
|
1085 |
+
### Assistant:""",
|
1086 |
+
"""
|
1087 |
+
### Assistant:
|
1088 |
+
""",
|
1089 |
+
]
|
1090 |
+
encounters = [1, 2]
|
1091 |
+
else:
|
1092 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
1093 |
+
stop_words = ['### End']
|
1094 |
+
encounters = [1]
|
1095 |
+
stop_words_ids = [
|
1096 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
1097 |
+
# handle single token case
|
1098 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
1099 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
1100 |
+
# avoid padding in front of tokens
|
1101 |
+
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
|
1102 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
1103 |
+
# handle fake \n added
|
1104 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
1105 |
+
# build stopper
|
1106 |
+
stopping_criteria = StoppingCriteriaList(
|
1107 |
+
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
|
1108 |
+
model_max_length=model_max_length)])
|
1109 |
+
else:
|
1110 |
+
stopping_criteria = StoppingCriteriaList()
|
1111 |
+
return stopping_criteria
|