Spaces:
Runtime error
Runtime error
from transformers import AutoTokenizer | |
from fastchat.conversation import get_conv_template | |
import os | |
from utils import sanitize_jinja2 | |
def test_llama2_template():
    """Check that the llama-2 Jinja chat template reproduces FastChat's prompt.

    The template file ``templates/llama-2.jinja2`` is read, passed through
    ``sanitize_jinja2`` and installed as ``chat_template`` on the
    ``microsoft/Orca-2-7b`` tokenizer (with ``<s>`` / ``</s>`` set as the
    BOS / EOS tokens). The same four-message conversation is then rendered
    through the tokenizer and through FastChat's ``"llama-2"`` conversation
    template, and all intermediate prompts are printed for inspection.

    Raises:
        AssertionError: if the prompt rendered with
            ``add_generation_prompt=True`` differs from FastChat's prompt.
    """
    # Resolve the template relative to this file so the test does not depend
    # on the directory it is launched from; keep the original cwd-relative
    # path as a fallback for layouts where the file-relative one is absent.
    here = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(here, os.pardir, "templates", "llama-2.jinja2")
    if not os.path.isfile(template_path):
        template_path = "../templates/llama-2.jinja2"

    with open(template_path, "r", encoding="utf-8") as f:
        jinja_lines = f.readlines()

    # Sanitize once so the value printed here is exactly what is installed
    # on the tokenizer below.
    sanitized_template = sanitize_jinja2(jinja_lines)
    print("jinja_lines: ", jinja_lines)
    print("sanitized: ", sanitized_template)

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path="microsoft/Orca-2-7b",
        trust_remote_code=True,
    )

    # First render with whatever template the tokenizer ships with.
    # NOTE(review): presumably that built-in template has this layout - confirm:
    # f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print("default template")
    print(transformer_prompt)

    # Switch to the special tokens and the template under test.
    tokenizer.bos_token = "<s>"
    tokenizer.eos_token = "</s>"
    tokenizer.chat_template = sanitized_template

    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print()
    print("add_generation_prompt False:")
    print(transformer_prompt)

    # This last render is the one compared against FastChat below.
    transformer_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print()
    print("add_generation_prompt True:")
    print(transformer_prompt)

    print("Fastchat template: ")
    conv = get_conv_template("llama-2")
    conv.set_system_message(chat[0]["content"])
    conv.append_message(conv.roles[0], chat[1]["content"])
    conv.append_message(conv.roles[1], chat[2]["content"])
    conv.append_message(conv.roles[0], chat[3]["content"])
    # A trailing message of None is appended for the second role, so the
    # FastChat prompt is built with an open final turn.
    conv.append_message(conv.roles[1], None)
    fastchat_prompt = conv.get_prompt()
    print(fastchat_prompt)

    assert transformer_prompt == fastchat_prompt, (
        "llama-2 Jinja template output differs from the FastChat prompt"
    )
# Allow running this check directly as a script, without a test runner.
if __name__ == "__main__":
    test_llama2_template()