nroggendorff committed
Commit cbac722 • 1 Parent(s): 5b65926
if this doesnt work i will be sad
prep.py
ADDED
@@ -0,0 +1,87 @@
import os
from itertools import islice
from datasets import load_dataset, Dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import PreTrainedTokenizerFast, AutoTokenizer
from config import Config

config = Config()

def load_data():
    if not config.INSTRUCT_FINETUNE_BOOL:
        dataset = load_dataset(config.INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
    else:
        dataset = load_dataset(config.INSTRUCT_DATASET, split="train", streaming=True)

    start = config.INIT * config.SHARD_SIZE
    data_list = list(islice(dataset, start, start + config.SHARD_SIZE))

    dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
    return dataset

def create_tokenizer(training_corpus):
    tokenizer = ByteLevelBPETokenizer()
    special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
    tokenizer.train_from_iterator(
        training_corpus,
        vocab_size=config.VOCAB_SIZE,
        min_frequency=2,
        special_tokens=special_tokens
    )
    fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
    return fast_tokenizer

def load_tokenizer():
    return AutoTokenizer.from_pretrained(config.OUTPUT_REPO + '-it' if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO)

def get_training_corpus(dataset):
    for i in range(0, len(dataset['text']), 1000):
        yield dataset['text'][i : i + 1000]

def configure_tokenizer(tokenizer):
    special_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>",
        "additional_special_tokens": []
    }
    if config.INSTRUCT_FINETUNE_BOOL:
        special_tokens["additional_special_tokens"] = ["<|user|>", "<|bot|>", "<|end|>"]
    tokenizer.add_special_tokens(special_tokens)

    if config.INSTRUCT_FINETUNE_BOOL:
        tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
        tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")

        chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
        tokenizer.chat_template = chat_template

def save_prepared_data(dataset, tokenizer):
    dataset.save_to_disk("prepared_dataset")
    tokenizer.save_pretrained("prepared_tokenizer")

def main():
    print("Loading Data..")
    dataset = load_data()
    print("Loaded data.")

    print("Making Corpus..")
    training_corpus = get_training_corpus(dataset)
    print("Made Corpus.")

    print("Making Tokenizer..")
    tokenizer = create_tokenizer(training_corpus)
    print(f"Made Tokenizer with size {len(tokenizer)}.")

    print("Adding Special Tokens..")
    configure_tokenizer(tokenizer)
    print("Added Tokens.")

    print("Saving Prepared Data..")
    save_prepared_data(dataset, tokenizer)
    print("Prepared data saved.")

if __name__ == "__main__":
    main()
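prep.py imports Config from a config module that is not part of this commit. As a point of reference, here is a minimal sketch of what such a Config might look like, limited to the attributes prep.py actually reads; every name and value below is an illustrative placeholder, not the Space's real configuration.

class Config:
    # Placeholder values -- substitute the Space's real settings.
    INSTRUCT_FINETUNE_BOOL = False  # False: pretraining data from INPUT_DATASET, True: instruct data from INSTRUCT_DATASET
    INPUT_DATASET = "your-org/your-pretraining-dataset"   # assumed: must expose a "cosmopedia-v2" config with a "text" column
    INSTRUCT_DATASET = "your-org/your-instruct-dataset"   # assumed: must expose a "text" column
    OUTPUT_REPO = "your-org/your-model"                   # assumed: load_tokenizer() appends "-it" in instruct mode
    INIT = 0                # index of the shard to start from
    SHARD_SIZE = 100_000    # number of streamed examples per shard
    VOCAB_SIZE = 32_000     # target vocabulary size for the BPE tokenizer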
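When INSTRUCT_FINETUNE_BOOL is set, configure_tokenizer attaches the Jinja chat template shown above. A minimal sketch of exercising that template through transformers' standard apply_chat_template, assuming the tokenizer was configured in instruct mode and saved to the prepared_tokenizer directory by this script; the messages are illustrative.

from transformers import AutoTokenizer

# Assumes save_prepared_data() already wrote the instruct-configured tokenizer here.
tokenizer = AutoTokenizer.from_pretrained("prepared_tokenizer")

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]

# Renders "<s><|user|>\nHello!<|end|>\n<|bot|>\nHi there.<|end|>\n</s>"
print(tokenizer.apply_chat_template(messages, tokenize=False))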
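The script leaves its outputs in two local directories. A minimal sketch of picking them up in a later training step, assuming the default paths used by save_prepared_data; the tokenization settings are illustrative.

from datasets import load_from_disk
from transformers import AutoTokenizer

# Paths written by save_prepared_data().
dataset = load_from_disk("prepared_dataset")
tokenizer = AutoTokenizer.from_pretrained("prepared_tokenizer")

# Tokenize the prepared "text" column in batches for downstream training.
tokenized = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=["text"],
)
print(tokenized)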