"""Manual comparison of the llama-2 Jinja2 chat template against FastChat.

Each test renders a sample conversation two ways and prints both results so
they can be compared by eye:

* with ``templates/llama-2.jinja2`` installed as the ``chat_template`` of a
  Hugging Face tokenizer (``apply_chat_template``), and
* with FastChat's built-in ``"llama-2"`` conversation template.

No equality is asserted yet (see the TODOs in the tests).
"""
import os
from pathlib import Path

from transformers import AutoTokenizer
from fastchat.conversation import get_conv_template

from utils import sanitize_jinja2

# Resolved relative to this file, not the current working directory, so the
# tests also work when launched from another directory (e.g. the repo root).
_LLAMA2_TEMPLATE_PATH = (
    Path(__file__).resolve().parent.parent / "templates" / "llama-2.jinja2"
)
_TOKENIZER_NAME = "microsoft/Orca-2-7b"


def _load_template_lines():
    """Return the raw lines of the llama-2 Jinja2 template file."""
    with open(_LLAMA2_TEMPLATE_PATH, "r", encoding="utf-8") as f:
        return f.readlines()


def _render_with_transformers(chat, template):
    """Print *chat* rendered by the tokenizer's default template and by *template*.

    Args:
        chat: list of ``{"role": ..., "content": ...}`` message dicts.
        template: Jinja2 source installed as the tokenizer's ``chat_template``.

    Returns:
        The *template* rendering with ``add_generation_prompt=True``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=_TOKENIZER_NAME, trust_remote_code=True
    )
    # The model's own template is presumably ChatML-style, e.g.
    # f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
    print("default template")
    print(tokenizer.apply_chat_template(chat, tokenize=False))

    # Blank BOS/EOS so any bos_token/eos_token references in the template
    # render as empty strings (presumably to line up with FastChat - confirm).
    tokenizer.bos_token = ""
    tokenizer.eos_token = ""
    tokenizer.chat_template = template

    prompt = None
    for add_generation_prompt in (False, True):
        prompt = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        print()
        print(f"add_generation_prompt {add_generation_prompt}:")
        print(prompt)
    return prompt


def _render_with_fastchat(chat):
    """Build *chat* with FastChat's "llama-2" template; print and return the prompt.

    A leading ``system`` message becomes the conversation's system message;
    ``user``/``assistant`` messages map to ``conv.roles[0]``/``conv.roles[1]``.
    """
    conv = get_conv_template("llama-2")
    role_names = {"user": conv.roles[0], "assistant": conv.roles[1]}
    for message in chat:
        if message["role"] == "system":
            conv.set_system_message(message["content"])
        else:
            conv.append_message(role_names[message["role"]], message["content"])
    # Trailing empty assistant turn, mirroring add_generation_prompt=True.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    print(prompt)
    return prompt


def _compare_templates(chat):
    """Print *chat* rendered via transformers (llama-2 Jinja2) and via FastChat.

    Returns:
        ``(transformers_prompt, fastchat_prompt)``, both with a generation prompt.
    """
    jinja_lines = _load_template_lines()
    print("jinja_lines: ", jinja_lines)
    # Single call; assumes sanitize_jinja2 has no side effects on its input.
    template = sanitize_jinja2(jinja_lines)
    print("sanitized: ", template)

    transformer_prompt = _render_with_transformers(chat, template)

    print("Fastchat template: ")
    fastchat_prompt = _render_with_fastchat(chat)
    return transformer_prompt, fastchat_prompt


def test_llama2_template():
    """Compare both renderings for a conversation with a system message."""
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]
    _compare_templates(chat)
    # TODO: assert the two prompts returned above are equal once confirmed.


def test_llama2_no_sys_prompt_template():
    """Compare both renderings for a conversation without a system message."""
    chat = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]
    _compare_templates(chat)
    # TODO: assert the two prompts returned above are equal once confirmed.


if __name__ == "__main__":
    test_llama2_template()
    test_llama2_no_sys_prompt_template()