improve code readability

#1
Files changed (1)
  1. train.py +101 -239
train.py CHANGED
@@ -11,27 +11,43 @@ from tokenizers import ByteLevelBPETokenizer
 from huggingface_hub import HfApi
 from torch.utils.data import DataLoader
 from itertools import islice
-
-BATCH_SIZE = 16
-EPOCHS = 3
-LEARNING_RATE = 2e-4
-FACTOR = 12 ** 3 // 3
-MAX_SEQ_LENGTH = 512
-VOCAB_SIZE = 32000
-INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
-INSTRUCT_DATASET = "nroggendorff/elephant"
-OUTPUT_REPO = "nroggendorff/smallama"
-INSTRUCT_FINETUNE_BOOL = False
-INIT = 0
-SHARD_SIZE = int(2e+5)
-FP16 = True
-WEIGHT_DECAY = 1e-3
-GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4
-
-PUSH_TO_HUB = True
-
-total_steps = (SHARD_SIZE * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
-WARMUP_STEPS = total_steps * 0.1
+from typing import Optional
+from logging import getLogger, StreamHandler, INFO
+
+# Logger setup
+logger = getLogger(__name__)
+logger.setLevel(INFO)
+handler = StreamHandler()
+logger.addHandler(handler)
+
+class Config:
+    # Model and training hyperparameters
+    BATCH_SIZE = 16
+    EPOCHS = 3
+    LEARNING_RATE = 2e-4
+    MAX_SEQ_LENGTH = 512
+    VOCAB_SIZE = 32000
+    FP16 = True
+    WEIGHT_DECAY = 1e-3
+    GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // 4
+
+    # Dataset configurations
+    INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
+    INSTRUCT_DATASET = "nroggendorff/elephant"
+    SHARD_SIZE = int(2e+5)
+
+    # Output and repo settings
+    OUTPUT_REPO = "nroggendorff/smallama"
+    PUSH_TO_HUB = True
+    INSTRUCT_FINETUNE_BOOL = False
+
+    # Training steps and warmup
+    FACTOR = 12 ** 3 // 3
+    TOTAL_STEPS = (SHARD_SIZE * EPOCHS) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)
+    WARMUP_STEPS = int(TOTAL_STEPS * 0.1)
+
+    # Initial state for shard offset
+    INIT = 0

 class Space:
     def __init__(self):
@@ -41,68 +57,42 @@ class Space:
 space = Space()

 class FineError(Exception):
-    def __init__(self, message="Script execution has completed."):
+    def __init__(self, message="Training completed successfully."):
         self.message = message
         super().__init__(self.message)

-def load_data():
-    if not INSTRUCT_FINETUNE_BOOL:
-        dataset = load_dataset(INPUT_DATASET, "cosmopedia-v2", split="train", streaming=True)
-    else:
-        dataset = load_dataset(INSTRUCT_DATASET, split="train", streaming=True)
-
-    start = INIT * SHARD_SIZE
-    data_list = list(islice(dataset, start, start + SHARD_SIZE))
-
-    dataset = Dataset.from_dict({'text': [example['text'] for example in data_list]})
-    return dataset
-
-def encode_decode(texts, tok):
-    if tok.pad_token is None:
-        tok.pad_token = tok.eos_token
-
-    tokenized_texts = tok(
-        texts,
-        padding="max_length",
-        truncation=True,
-        max_length=MAX_SEQ_LENGTH,
-        return_tensors="pt"
+def load_data(dataset_name: str, split: str, shard_size: int, init_offset: int = 0) -> Dataset:
+    dataset = load_dataset(dataset_name, split=split, streaming=True)
+    shard_start = init_offset * shard_size
+    data_list = list(islice(dataset, shard_start, shard_start + shard_size))
+    return Dataset.from_dict({'text': [example.get('text', '') for example in data_list]})
+
+def encode_decode(texts, tokenizer):
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    tokenized_texts = tokenizer(
+        texts, padding="max_length", truncation=True, max_length=Config.MAX_SEQ_LENGTH, return_tensors="pt"
     ).input_ids
-
-    if tokenized_texts.dim() >= 1:
-        decoded_texts = tok.batch_decode(tokenized_texts)
-    else:
-        print('Found invalid entry in examples. Returning dummy..')
-        decoded_texts = [tokenizer.pad_token * MAX_SEQ_LENGTH]
-
-    islist = not len(decoded_texts) == 1
-
-    return decoded_texts if islist else decoded_texts[0]
+    return tokenizer.batch_decode(tokenized_texts) if tokenized_texts.dim() >= 1 else [tokenizer.pad_token * Config.MAX_SEQ_LENGTH]

 def create_tokenizer(training_corpus):
     tokenizer = ByteLevelBPETokenizer()
     special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
-    tokenizer.train_from_iterator(
-        training_corpus,
-        vocab_size=VOCAB_SIZE,
-        min_frequency=2,
-        special_tokens=special_tokens
-    )
-    fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)
-    return fast_tokenizer
+    tokenizer.train_from_iterator(training_corpus, vocab_size=Config.VOCAB_SIZE, min_frequency=2, special_tokens=special_tokens)
+    return PreTrainedTokenizerFast(tokenizer_object=tokenizer._tokenizer)

-def load_tokenizer():
-    return AutoTokenizer.from_pretrained(OUTPUT_REPO + '-it' if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO)
+def load_tokenizer(repo: str):
+    return AutoTokenizer.from_pretrained(repo)

 def get_training_corpus(dataset):
     for i in range(0, len(dataset['text']), 1000):
         yield dataset['text'][i : i + 1000]

-def format_prompts(examples, tokenizer, isinst):
+def format_prompts(examples, tokenizer, is_instructional):
     texts = []
     for text in examples['text']:
         if text and len(text.strip()) > 0:
-            if isinst:
+            if is_instructional:
                 conversation = []
                 parts = text.split('<|end|>')
                 for i in range(0, len(parts) - 1, 2):
@@ -110,29 +100,22 @@ def format_prompts(examples, tokenizer, isinst):
                     response = parts[i + 1].replace("<|bot|>", "").strip()
                     conversation.append({"role": "user", "content": prompt})
                     conversation.append({"role": "assistant", "content": response})
-                formatted_conversation = tokenizer.apply_chat_template(conversation, tokenize=False)
-                coded_text = tokenizer.code(formatted_conversation)
+                coded_text = tokenizer.code(tokenizer.apply_chat_template(conversation, tokenize=False))
                 texts.append(coded_text)
             else:
                 texts.append(tokenizer.bos_token + tokenizer.code(text) + tokenizer.eos_token)
-        else:
-            print('Found empty entry in examples. Moving on..')
-            continue
-
-    if len(texts) == 0:
+    if not texts:
         raise ValueError("No valid texts found in examples for formatting.")
-
-    coded_texts = tokenizer.code(texts)
-    return {'text': coded_texts}
+    return {'text': tokenizer.code(texts)}

 def create_model(tokenizer):
     config = LlamaConfig(
         vocab_size=tokenizer.vocab_size,
-        hidden_size=FACTOR,
-        intermediate_size=FACTOR * 4,
+        hidden_size=Config.FACTOR,
+        intermediate_size=Config.FACTOR * 4,
         num_hidden_layers=12,
         num_attention_heads=12,
-        max_position_embeddings=MAX_SEQ_LENGTH,
+        max_position_embeddings=Config.MAX_SEQ_LENGTH,
         rms_norm_eps=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -143,175 +126,54 @@ def create_model(tokenizer):
     )
     return LlamaForCausalLM(config)

-def load_model():
-    return AutoModelForCausalLM.from_pretrained(OUTPUT_REPO + '-it' if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO)
-
-def configure_tokenizer(tokenizer):
-    special_tokens = {
-        "bos_token": "<s>",
-        "eos_token": "</s>",
-        "unk_token": "<unk>",
-        "pad_token": "<pad>",
-        "mask_token": "<mask>",
-        "additional_special_tokens": []
-    }
-    if INSTRUCT_FINETUNE_BOOL:
-        special_tokens["additional_special_tokens"] = ["<|user|>", "<|bot|>", "<|end|>"]
-    tokenizer.add_special_tokens(special_tokens)
-
-    if INSTRUCT_FINETUNE_BOOL:
-        tokenizer.user_token_id = tokenizer.convert_tokens_to_ids("<|user|>")
-        tokenizer.assistant_token_id = tokenizer.convert_tokens_to_ids("<|bot|>")
-
-    chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '<|end|>\n' }}{% elif message['role'] == 'assistant' %}{{ '<|bot|>\n' + message['content'] + '<|end|>\n' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
-    tokenizer.chat_template = chat_template
-
-    tokenizer.code = lambda example: encode_decode(example, tokenizer)
-
-def update_tokenizer(tokenizer, dataset, batch_size=1000):
-    existing_vocab = tokenizer.get_vocab()
-    oov_tokens = set()
-
-    for i in range(0, len(dataset['text']), batch_size):
-        batch = dataset['text'][i:i + batch_size]
-
-        for text in batch:
-            token_ids = tokenizer.encode(text, add_special_tokens=False)
-
-            for token_id in token_ids:
-                token = tokenizer.decode([token_id])
-                if token.strip() and token not in existing_vocab:
-                    oov_tokens.add(token)
-
-    if oov_tokens:
-        num_added = tokenizer.add_tokens(list(oov_tokens))
-        return num_added
-
-    return 0
-
-def train_model(model, tokenizer, dataset, push, isinst):
+def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
     args = TrainingArguments(
         output_dir="model",
-        num_train_epochs=EPOCHS,
-        per_device_train_batch_size=BATCH_SIZE,
-        learning_rate=LEARNING_RATE,
-        optim="adamw_torch",
-        warmup_steps=WARMUP_STEPS,
-        weight_decay=WEIGHT_DECAY,
-        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
-        fp16=FP16,
-        save_steps=WARMUP_STEPS * 5,
-        logging_steps=WARMUP_STEPS,
-        eval_strategy="no",
-        report_to="no",
-        # eval_steps=WARMUP_STEPS,
+        num_train_epochs=Config.EPOCHS,
+        per_device_train_batch_size=Config.BATCH_SIZE,
+        learning_rate=Config.LEARNING_RATE,
+        warmup_steps=Config.WARMUP_STEPS,
+        weight_decay=Config.WEIGHT_DECAY,
+        gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
+        fp16=Config.FP16,
+        save_steps=int(Config.WARMUP_STEPS * 5),
+        logging_steps=int(Config.WARMUP_STEPS),
         save_total_limit=2,
+        report_to="none",
     )
-
-    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=WEIGHT_DECAY)
-    scheduler = get_cosine_schedule_with_warmup(
-        optimizer,
-        num_warmup_steps=args.warmup_steps,
-        num_training_steps=total_steps
-    )
-
-    dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer, isinst), batched=True, remove_columns=dataset.column_names)
-
-    if 'text' not in dataset.column_names:
-        raise ValueError("Dataset transformation failed: 'text' column missing after mapping.")
-
-    print("Mapped dataset sample length:", len(dataset[0]['text']))
-
-    try:
-        test_input = tokenizer(
-            ["This is a test input."],
-            return_tensors="pt",
-            padding="max_length",
-            truncation=True,
-            max_length=MAX_SEQ_LENGTH
-        )
-        test_output = model(**test_input)
-        print("Model test output shape:", test_output.logits.shape)
-    except RuntimeError as e:
-        print(f"Error processing test batch: {e}")
-
-    trainer = trl.SFTTrainer(
-        model=model,
-        tokenizer=tokenizer,
-        args=args,
-        train_dataset=dataset,
-        # dataset_text_field='text',
-        max_seq_length=MAX_SEQ_LENGTH,
-        optimizers=(optimizer, scheduler)
+    dataset = dataset.map(
+        lambda examples: format_prompts(examples, tokenizer, is_instructional),
+        batched=True,
+        remove_columns=dataset.column_names
    )
-
-    train = trainer.train()
-
-    trained_model = trainer.model
-    trained_tokenizer = trainer.tokenizer
-
-    if push:
-        repo_id = OUTPUT_REPO + "-it" if INSTRUCT_FINETUNE_BOOL else OUTPUT_REPO
-        msg = f"Training loss: {train.training_loss:.4f}"
-        trained_model.push_to_hub(repo_id, commit_message=msg, force=True)
-        trained_tokenizer.push_to_hub(repo_id, commit_message=msg, force=True)
-    else:
-        trained_model.save_pretrained("model")
-        trained_tokenizer.save_pretrained("tokenizer")
-
-def main(push_to_hub=True, is_inst_finetune=False):
-    print("Loading Data..")
-    dataset = load_data()
-    print("Loaded data.")
-
-    if is_inst_finetune and INIT > 0:
-        print("Loading Tokenizer..")
-        tokenizer = load_tokenizer()
-        print("Loaded Tokenizer.")
-    else:
-        print("Making Corpus..")
-        training_corpus = get_training_corpus(dataset)
-        print("Made Corpus.")
+    trainer = trl.SFTTrainer(model=model, tokenizer=tokenizer, args=args, train_dataset=dataset)
+    train_result = trainer.train()

-        print("Making Tokenizer..")
-        tokenizer = create_tokenizer(training_corpus)
-        print(f"Made Tokenizer with size {len(tokenizer)}.")
-
-    # print("Adding Tokens..")
-    # num_new_tokens = update_tokenizer(tokenizer, dataset)
-    # print(f"Added {num_new_tokens} new tokens to the vocabulary")
-
-    if INIT == 0:
-        print("Adding Special Tokens..")
-        configure_tokenizer(tokenizer)
-        print("Added Tokens.")
-
-    if is_inst_finetune or INIT > 0:
-        print("Loading Model..")
-        model = load_model()
-        print("Loaded Model.")
+    if push_to_hub:
+        repo_id = Config.OUTPUT_REPO + "-it" if Config.INSTRUCT_FINETUNE_BOOL else Config.OUTPUT_REPO
+        trainer.model.push_to_hub(repo_id, commit_message=f"Training loss: {train_result.training_loss:.4f}", force=True)
+        trainer.tokenizer.push_to_hub(repo_id, commit_message=f"Training loss: {train_result.training_loss:.4f}", force=True)
     else:
-        print("Creating Model..")
-        model = create_model(tokenizer)
-        print("Created Model.")
-
-    print(f"Tokenizer vocabulary size: {len(tokenizer)}")
-    print(f"Special tokens: {tokenizer.special_tokens_map}")
-
-    print("Resizing Token Embeddings..")
-    try:
-        model.resize_token_embeddings(len(tokenizer))
-    except RuntimeError as e:
-        raise RuntimeError(f"Error resizing token embeddings: {e}")
-    print("Resized Embeddings.")
-
-    print("Training Model..")
-    train_model(model, tokenizer, dataset, push_to_hub, is_inst_finetune)
-    raise FineError("Trained Model.")
+        trainer.model.save_pretrained("model")
+        trainer.tokenizer.save_pretrained("tokenizer")
+
+def main():
+    dataset = load_data(Config.INPUT_DATASET, "train", Config.SHARD_SIZE, Config.INIT)
+    tokenizer = (
+        load_tokenizer(Config.OUTPUT_REPO)
+        if Config.INSTRUCT_FINETUNE_BOOL and Config.INIT > 0
+        else create_tokenizer(get_training_corpus(dataset))
+    )
+    model = (
+        load_model()
+        if Config.INSTRUCT_FINETUNE_BOOL or Config.INIT > 0
+        else create_model(tokenizer)
+    )
+    train_model(model, tokenizer, dataset, Config.PUSH_TO_HUB, Config.INSTRUCT_FINETUNE_BOOL)

 if __name__ == "__main__":
     try:
-        main(PUSH_TO_HUB, INSTRUCT_FINETUNE_BOOL)
+        main()
     except Exception as e:
-        print(f'{type(e).__name__}: {e}')
+        logger.error(f"{type(e).__name__}: {e}")
         space.pause()
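
A note on the resulting file rather than the diff itself: the new main() still calls load_model() when Config.INSTRUCT_FINETUNE_BOOL or Config.INIT > 0, and format_prompts() still relies on the tokenizer.code(...) helper and the chat template, but the definitions of load_model and configure_tokenizer are removed above without replacements. A minimal sketch of how they could be reintroduced on top of the new Config class, adapted from the removed code (a suggestion only, assuming AutoModelForCausalLM is still imported at the top of train.py):

def load_model():
    # Sketch: same repo-selection logic as the removed helper, now read from Config.
    repo = Config.OUTPUT_REPO + "-it" if Config.INSTRUCT_FINETUNE_BOOL else Config.OUTPUT_REPO
    return AutoModelForCausalLM.from_pretrained(repo)

def configure_tokenizer(tokenizer):
    # Sketch: re-attach the special tokens and the tokenizer.code helper that format_prompts() expects.
    special_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>",
    }
    if Config.INSTRUCT_FINETUNE_BOOL:
        special_tokens["additional_special_tokens"] = ["<|user|>", "<|bot|>", "<|end|>"]
    tokenizer.add_special_tokens(special_tokens)
    # The chat_template string from the removed configure_tokenizer would be re-assigned here as well,
    # so apply_chat_template() keeps working for instruct fine-tuning.
    tokenizer.code = lambda example: encode_decode(example, tokenizer)

main() would then call configure_tokenizer(tokenizer) after building the tokenizer, as the removed version did when INIT == 0.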