DavidGF commited on
Commit
e1c643a
1 Parent(s): 7830152

Delete tokenizer_template_switch.py

Browse files
Files changed (1) hide show
  1. tokenizer_template_switch.py +0 -95
tokenizer_template_switch.py DELETED
@@ -1,95 +0,0 @@
1
- import re
2
- from transformers import AutoTokenizer
3
-
4
- def extract_separators(template):
5
- """
6
- Extracts separators used in the tokenization template.
7
- """
8
- # Adjust the regex to correctly match the specific pattern between '{{' and '+ message["content"] +'
9
- pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
10
- matches = re.findall(pattern, template)
11
- # Clean up any extra spaces and return the matches
12
- separators = [match.strip() for match in matches]
13
-
14
- if any("message['role']" in element for element in separators):
15
- roles = ["system", "user", "assistant"]
16
- separators_ = []
17
- for role in roles:
18
- separators_.append(separators[0].replace(" + message['role'] + ", role).replace("'",""))
19
- return separators_
20
-
21
- return separators
22
-
23
- def detect_eos_token(jinja_template, tokenizer):
24
- if "<|im_end|>" in jinja_template:
25
- return "<|im_end|>"
26
- if "</s>" in jinja_template:
27
- return "</s>"
28
- if "eos_token" in jinja_template:
29
- return tokenizer.eos_token
30
- else:
31
- return "<|endoftext|>"
32
-
33
- def recover_messages(formatted_message, separators, eos_token):
34
- """
35
- Recovers the original messages from the formatted message string.
36
- """
37
- # Split the formatted message using the end-of-string token
38
- split_messages = formatted_message.split(eos_token)
39
-
40
- # Remove the last empty string if it exists due to a trailing separator
41
- if split_messages and split_messages[-1].strip() == '':
42
- split_messages.pop()
43
-
44
- # Prepare the list to hold the recovered messages
45
- recovered_messages = []
46
-
47
- # Define roles after the first message, alternating between "user" and "assistant"
48
- alternate_roles = ["user", "assistant"]
49
-
50
- # Iterate over the split messages
51
- for index, message_content in enumerate(split_messages):
52
- # Determine the role, starting with "system" for the first message
53
- # then alternating between "user" and "assistant" for subsequent messages
54
- if index == 0:
55
- role = "system"
56
- else:
57
- role = alternate_roles[(index - 1) % 2]
58
-
59
- # Clean the message content by removing leading/trailing whitespace and separators
60
- clean_content = message_content.strip()
61
- for separator in separators:
62
- clean_content = clean_content.replace(separator.strip("'"), '', 1).strip()
63
-
64
- # Append the cleaned message with its role to the list
65
- recovered_messages.append({"role": role, "content": clean_content})
66
-
67
- return recovered_messages
68
-
69
- def recover_chat_messages(tokenized_chat, tokenizer):
70
- """
71
- Given a tokenized_chat string and a tokenizer, returns the list of message dictionaries.
72
- """
73
- jinja_template = tokenizer.chat_template
74
- separators = extract_separators(jinja_template)
75
- eos_token = eos_token = detect_eos_token(jinja_template, tokenizer)
76
- recovered_messages = recover_messages(tokenized_chat, separators, eos_token)
77
- return recovered_messages
78
-
79
- # Example usage
80
- if __name__ == "__main__":
81
- checkpoint = "Qwen/Qwen1.5-0.5B"
82
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
83
-
84
- messages = [
85
- {
86
- "role": "system",
87
- "content": "You are a friendly chatbot who always responds in the style of a pirate",
88
- },
89
- {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
90
- ]
91
- tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
92
- print(tokenized_chat)
93
-
94
- recovered_messages = recover_chat_messages(tokenized_chat, tokenizer)
95
- print(recovered_messages)