Crystalcareai committed
Commit fe73328
1 Parent(s): 6524016

Create optuna.py

Files changed (1):
  1. optuna.py +153 -0
optuna.py ADDED
# Caution: this script is saved as optuna.py, which shadows the optuna package
# when run from its own directory; rename the local copy before running.
import optuna
import torch
import random
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
import time

# Set random seed for reproducibility
random_seed = 42
torch.manual_seed(random_seed)
random.seed(random_seed)

# Load dataset
dataset = load_dataset("tatsu-lab/alpaca", split="train")


def chatml_format(example):
    """Format the dataset for training, accounting for empty columns."""
    return {
        "instruction": example['instruction'] if 'instruction' in example else " \n",
        "input": example['input'] if 'input' in example else " \n",
        "system": example['system'] if 'system' in example else " \n",
        "output": example['output'] if 'output' in example else " \n",
    }


# Format dataset
dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
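# Optional sanity check: each mapped record should expose exactly the four
# keys produced above (tatsu-lab/alpaca has no 'system' column, so that field
# falls back to " \n").
assert set(dataset[0]) == {"instruction", "input", "system", "output"}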

# Define the model initialization function
def model_init(trial=None):
    original = False
    params = {}
    # Fixed Quiet-STaR settings (not tuned by the Optuna trial); defined
    # unconditionally so the function also works when called without a trial.
    n_ahead = 1
    n_ahead_talk = 1
    n_passes = 1
    gumbel_temperature = 1
    use_start_thought_token = True
    use_end_thought_token = True
    include_policy_loss = True
    gumbel_detach = True
    merged_talk_heads = True
    residual_think_head = False
    optimize_lm_head_only_at_start = False
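    # Any of these fixed values could be exposed to the search instead; a
    # hypothetical example:
    #   n_ahead = trial.suggest_int("n_ahead", 1, 4)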

    model_id = "Crystalcareai/Quiet-Star-Custom"
    tokenizer_id = model_id

    # The non-standard kwargs below (max_thoughts and the head options) are
    # Quiet-STaR config options handled by the model's remote code, which is
    # why trust_remote_code=True is required.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        max_thoughts=n_ahead + n_ahead_talk + 1,
        merged_talk_heads=merged_talk_heads,
        merged_lm_and_talk_heads=False,
        merged_lm_and_think_heads=True,
        use_concat_talk_head=True,
        use_shallow_think=True,
        use_shallow_talk=False,
        use_complex_think_head=False,
        use_complex_talk_head=True,
        use_weighted_talk_head=True,
        trust_remote_code=True,
        device_map="auto",
    )

    # `padding_side` (not `padding`) is the kwarg that selects left-padding
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, truncation=True, padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    special_tokens_to_add = []
    if model.use_start_thought_token:
        special_tokens_to_add.append("<|startthought|>")
    if model.use_end_thought_token:
        special_tokens_to_add.append("<|endthought|>")
    if special_tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
        model.resize_token_embeddings(len(tokenizer))
    model.tokenizer = tokenizer
    for name, module in model.named_modules():
        if "embed" in name:
            print(module, flush=True)

    model.gumbel_detach = gumbel_detach
    model.include_policy_loss = include_policy_loss
    model.use_end_thought_token = use_end_thought_token
    model.use_start_thought_token = use_start_thought_token
    model.n_ahead = n_ahead
    model.n_ahead_talk = n_ahead_talk
    model.n_passes = n_passes
    model.residual_think_head = residual_think_head
    model.gumbel_temperature = gumbel_temperature
    model.original_mode = original
    model.config_params = params
    model.run_start = int(time.time())
    model.train()
    return model

# Define the objective function for Optuna
def objective(trial):
    # Hyperparameters to be optimized
    learning_rate = trial.suggest_float("learning_rate", 1e-07, 1e-06, log=True)
    max_grad_norm = trial.suggest_float("max_grad_norm", 0.3, 1.0)
    warmup_steps = trial.suggest_int("warmup_steps", 0, 20)
    gradient_accumulation_steps = trial.suggest_int("gradient_accumulation_steps", 4, 8)

    model = model_init(trial)

    training_args = TrainingArguments(
        output_dir="./out",
        num_train_epochs=3,
        max_steps=30,  # caps each trial at 30 optimizer steps, overriding num_train_epochs
        per_device_train_batch_size=1,
        logging_steps=1,
        optim="lion_32bit",  # requires bitsandbytes
        save_strategy="steps",
        save_steps=3000,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        max_grad_norm=max_grad_norm,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",
        report_to="none",  # disable reporting to avoid WandB-related errors in this context
    )

    trainer = SFTTrainer(
        args=training_args,
        train_dataset=dataset,
        model=model,
        tokenizer=model.tokenizer,
        max_seq_length=1024,
        dataset_text_field="output",
    )

    # Train the model and get the training loss
    train_result = trainer.train()
    loss = train_result.training_loss

    return loss

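
# A possible extension (not implemented here): report intermediate losses to
# Optuna from a TrainerCallback via trial.report(...) and raise
# optuna.TrialPruned() so unpromising trials stop early; as written, every
# trial runs the full 30 steps.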

# Create a study and optimize. create_study defaults to direction="minimize",
# which matches the training-loss objective above; the SQLite storage
# persists every trial to db.sqlite3.
study = optuna.create_study(storage="sqlite:///db.sqlite3")
study.optimize(objective, n_trials=100)

# Print the best trial
print("Best trial:")
trial = study.best_trial
print(f"  Loss: {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
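
# Because the study is persisted, it can be reloaded in a later session. A
# minimal sketch, assuming db.sqlite3 contains exactly one study
# (study_name=None is only valid in that case):
#
#   completed = optuna.load_study(study_name=None, storage="sqlite:///db.sqlite3")
#   print(completed.best_trial.params)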