nroggendorff committed on
Commit
928e52a
1 Parent(s): f52c925

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -128
train.py DELETED
@@ -1,128 +0,0 @@
1
- import os
2
-
3
- import torch
4
- import trl
5
-
6
- from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, TrainingArguments, PreTrainedTokenizerFast
7
- from datasets import load_dataset
8
- from tokenizers import ByteLevelBPETokenizer
9
-
10
# --- Training hyperparameters -------------------------------------------
# Used both as the model's max_position_embeddings and as the trainer's
# max_seq_length.
MAX_SEQ_LENGTH: int = 512
# Passed as per_device_train_batch_size.
BATCH_SIZE: int = 768
# Passed as num_train_epochs.
EPOCHS: int = 8
LEARNING_RATE: float = 1e-4
# Forwarded to TrainingArguments(fp16=...).
FP16: bool = True

# --- Model / tokenizer sizing -------------------------------------------
# Scales the model: hidden_size is FACTOR, intermediate_size is FACTOR * 2,
# and layer/head counts are max(1, FACTOR // 64).
FACTOR: int = 2
# Target vocabulary size used when training the BPE tokenizer.
VOCAB_SIZE: int = 3200

# --- Hub locations ------------------------------------------------------
# Dataset loaded for training (its "train" split is used).
INPUT_DATASET: str = "nroggendorff/elephant"
# Repository id the trained model and tokenizer are pushed to.
OUTPUT_REPO: str = "smallama"
19
-
20
def load_data():
    """Load and return the "train" split of ``INPUT_DATASET``."""
    return load_dataset(INPUT_DATASET, split="train")
23
-
24
def create_tokenizer(training_corpus):
    """Train a byte-level BPE tokenizer and wrap it for use with transformers.

    Args:
        training_corpus: Iterator handed straight to
            ``ByteLevelBPETokenizer.train_from_iterator``.

    Returns:
        A ``PreTrainedTokenizerFast`` built around the trained tokenizer.
    """
    reserved_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>", "<|user|>", "<|bot|>", "<|end|>"]

    bpe = ByteLevelBPETokenizer()
    bpe.train_from_iterator(
        training_corpus,
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        special_tokens=reserved_tokens,
    )

    # The fast-tokenizer wrapper is built from the object stored on the
    # trainer wrapper's `_tokenizer` attribute, not from the wrapper itself.
    return PreTrainedTokenizerFast(tokenizer_object=bpe._tokenizer)
35
-
36
def get_training_corpus(dataset, batch_size=1000):
    """Yield the dataset's "text" column in consecutive batches.

    Args:
        dataset: Object supporting ``len()`` and slice indexing, where a
            slice returns a mapping with a ``"text"`` entry.
        batch_size: Number of rows per yielded batch. Defaults to 1000,
            the value that was previously hard-coded.

    Yields:
        ``dataset[start : start + batch_size]["text"]`` for each batch; the
        final batch may be shorter. An empty dataset yields nothing.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised when iteration
            starts, since this is a generator).
    """
    # A negative step would silently yield nothing, so reject it explicitly.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    for start in range(0, len(dataset), batch_size):
        yield dataset[start : start + batch_size]["text"]
39
-
40
def format_prompts(examples, tokenizer):
    """Re-render raw conversation strings through the tokenizer's chat template.

    Each string in ``examples['text']`` is split on ``<|end|>``. Consecutive
    pieces are paired as (user prompt, assistant response), with the
    ``<|user|>`` marker removed from the prompt and the ``<|bot|>`` marker
    removed from the response. A final piece without a partner is ignored.

    Args:
        examples: Batch mapping with a ``'text'`` list of conversation strings.
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        ``{"text": [...]}`` with one templated string per input string.
    """

    def to_messages(raw_text):
        # Even-indexed pieces are user turns, odd-indexed pieces the replies.
        pieces = raw_text.split('<|end|>')
        messages = []
        for user_piece, bot_piece in zip(pieces[::2], pieces[1::2]):
            messages.append({"role": "user", "content": user_piece.replace("<|user|>", "")})
            messages.append({"role": "assistant", "content": bot_piece.replace("<|bot|>", "")})
        return messages

    return {
        "text": [
            tokenizer.apply_chat_template(to_messages(raw_text), tokenize=False)
            for raw_text in examples['text']
        ]
    }
53
-
54
def create_model(tokenizer):
    """Build a freshly initialised Llama causal LM sized from ``FACTOR``.

    Args:
        tokenizer: Tokenizer supplying the vocabulary size and the
            pad/bos/eos token ids written into the model config.

    Returns:
        A ``LlamaForCausalLM`` constructed from the new config (no
        pretrained weights are loaded).
    """
    # FACTOR // 64 is 0 for small FACTOR values, so clamp to at least one
    # layer / one attention head.
    layers_and_heads = max(1, FACTOR // 64)

    config = LlamaConfig(
        # len(tokenizer) counts added tokens as well, whereas
        # tokenizer.vocab_size covers only the base vocabulary; sizing the
        # embeddings from the full length keeps every token id in range.
        vocab_size=len(tokenizer),
        hidden_size=FACTOR,
        intermediate_size=FACTOR * 2,
        num_hidden_layers=layers_and_heads,
        num_attention_heads=layers_and_heads,
        max_position_embeddings=MAX_SEQ_LENGTH,
        rms_norm_eps=1e-6,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        tie_word_embeddings=False,
    )

    return LlamaForCausalLM(config)
73
-
74
def configure_tokenizer(tokenizer):
    """Register special tokens, role-token ids and the chat template in place.

    Args:
        tokenizer: Tokenizer to mutate. It receives the special-token
            mapping, two extra attributes (``user_token_id`` and
            ``assistant_token_id``) and a ``chat_template`` string.

    Returns:
        None. The tokenizer is modified in place.
    """
    tokenizer.add_special_tokens(
        {
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
            "pad_token": "<pad>",
            "mask_token": "<mask>",
            "additional_special_tokens": ["<|user|>", "<|bot|>", "<|end|>"],
        }
    )

    # Store the ids of the two role markers on the tokenizer itself.
    for attr_name, marker in (
        ("user_token_id", "<|user|>"),
        ("assistant_token_id", "<|bot|>"),
    ):
        setattr(tokenizer, attr_name, tokenizer.convert_tokens_to_ids(marker))

    # Adjacent literals are concatenated into a single template string:
    # bos, then each message wrapped in its role marker and <|end|>, then
    # eos. Roles must strictly alternate user/assistant.
    tokenizer.chat_template = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
        "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
        "{% endif %}"
        "{% if message['role'] == 'user' %}"
        "{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ '<|bot|>\n' + message['content'] + '<|end|>\n' }}"
        "{% else %}"
        "{{ raise_exception('Only user and assistant roles are supported!') }}"
        "{% endif %}"
        "{% endfor %}"
        "{{ eos_token }}"
    )
90
-
91
def train_model(model, tokenizer, dataset):
    """Run supervised fine-tuning and push the model and tokenizer to the Hub.

    Args:
        model: Model to train.
        tokenizer: Tokenizer used both to format prompts and by the trainer.
        dataset: Dataset with a ``'text'`` column; it is rewritten through
            ``format_prompts`` before training.

    Returns:
        None. The trained model and tokenizer are pushed to ``OUTPUT_REPO``.
    """
    training_args = TrainingArguments(
        output_dir="model",
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        fp16=FP16,
        optim="sgd",
    )

    def apply_chat_format(batch):
        # Bind the tokenizer so dataset.map only has to supply the batch.
        return format_prompts(batch, tokenizer)

    formatted_dataset = dataset.map(apply_chat_format, batched=True)

    sft_trainer = trl.SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=formatted_dataset,
        dataset_text_field='text',
        max_seq_length=MAX_SEQ_LENGTH,
    )
    sft_trainer.train()

    # Publish the objects held by the trainer, model first, then tokenizer.
    sft_trainer.model.push_to_hub(OUTPUT_REPO)
    sft_trainer.tokenizer.push_to_hub(OUTPUT_REPO)
117
-
118
def main():
    """Run the full pipeline: load data, build tokenizer and model, train."""
    dataset = load_data()

    # The tokenizer is trained on the dataset's text, then configured in
    # place before the model is sized from it.
    tokenizer = create_tokenizer(get_training_corpus(dataset))
    configure_tokenizer(tokenizer)

    train_model(create_model(tokenizer), tokenizer, dataset)
125
-
126
if __name__ == "__main__":
    # Script entry point: run the training pipeline.
    main()
    # NOTE(review): raising once main() has returned makes the process exit
    # with an error even after a successful run. This looks deliberate
    # (presumably to halt the hosting environment) - confirm before removing.
    raise RuntimeError("The script is finished.")