File size: 2,554 Bytes
4d82421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from transformers import AutoTokenizer
from fastchat.conversation import get_conv_template
import os
from utils import sanitize_jinja2
import difflib

def test_llama2_template() -> None:
    """Check that the custom Jinja2 chat template matches FastChat's prompt.

    Despite its name, this test exercises the OpenHermes-2.5-Mistral
    template: it reads ``../templates/openhermes-2.5-mistral.jinja2``,
    installs the sanitized template on the
    ``teknium/OpenHermes-2.5-Mistral-7B`` tokenizer, renders a short
    conversation with ``add_generation_prompt=True`` and compares the result
    with the prompt produced by FastChat's ``OpenHermes-2.5-Mistral-7B``
    conversation template.

    Intermediate prompts are printed so a mismatch can be inspected by eye.

    Raises:
        AssertionError: If the Transformers-rendered prompt differs from the
            FastChat prompt.
    """
    # NOTE: the path is relative to the current working directory, so the
    # test must be started from a directory where "../templates" resolves.
    # The encoding is pinned so the template is decoded the same way on
    # every platform instead of following the locale default.
    with open("../templates/openhermes-2.5-mistral.jinja2", "r", encoding="utf-8") as f:
        jinja_lines = f.readlines()

    print("jinja_lines: ", jinja_lines)

    print("sanitized: ", sanitize_jinja2(jinja_lines))

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
        {"role": "user", "content": "I'd like to show off how chat templating works!"},
    ]

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path="teknium/OpenHermes-2.5-Mistral-7B",
        trust_remote_code=True,
    )
    # Expected prompt layout:
    # f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"

    # First render with whatever template ships with the tokenizer, purely
    # as a printed reference.
    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print("default template")
    print(transformer_prompt)

    # Swap in the template under test.
    tokenizer.chat_template = sanitize_jinja2(jinja_lines)

    transformer_prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    print()
    print("add_generation_prompt False:")
    print(transformer_prompt)

    # This last rendering is the one compared against FastChat below.
    transformer_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print()
    print("add_generation_prompt True:")
    print(transformer_prompt)

    print("Fastchat template: ")
    conv = get_conv_template("OpenHermes-2.5-Mistral-7B")

    # Replay the same conversation through FastChat; the trailing
    # (assistant, None) message mirrors add_generation_prompt=True.
    conv.set_system_message(chat[0]["content"])
    conv.append_message(conv.roles[0], chat[1]["content"])
    conv.append_message(conv.roles[1], chat[2]["content"])
    conv.append_message(conv.roles[0], chat[3]["content"])
    conv.append_message(conv.roles[1], None)

    # Build the FastChat prompt once and reuse it for printing, diffing and
    # the final comparison.
    fastchat_prompt = conv.get_prompt()
    print(fastchat_prompt)

    matcher = difflib.SequenceMatcher(a=transformer_prompt, b=fastchat_prompt)
    print("Matching Sequences:")
    for match in matcher.get_matching_blocks():
        print(f"Match             : {match}")
        print(f"Matching Sequence : {transformer_prompt[match.a:match.a + match.size]}")

    assert transformer_prompt == fastchat_prompt, (
        "Jinja2 chat template output differs from the FastChat prompt"
    )

# Allow running this file directly as a script, without a test runner.
if __name__ == "__main__":
    test_llama2_template()