AmelieSchreiber committed on
Commit d504f0a
1 Parent(s): c0b7d88

Upload 8 files

clustered_ppi_train.py ADDED
@@ -0,0 +1,160 @@
+ from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
+ from torch.utils.data import Dataset
+ import pandas as pd
+ import torch
+ from torch.optim import AdamW
+ import random
+ import datetime
+
+ class ProteinDataset(Dataset):
+     def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
+         self.tokenizer = tokenizer
+         self.proteins = proteins
+         self.peptides = peptides
+         self.mask_percentage = mask_percentage
+
+     def __len__(self):
+         return len(self.proteins)
+
+     def mask_sequence(self, sequence):
+         # Replace a random subset of residues with the mask token
+         mask_indices = random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage))
+         return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])
+
+     def __getitem__(self, idx):
+         protein_seq = self.proteins[idx]
+         peptide_seq = self.peptides[idx]
+
+         masked_protein = self.mask_sequence(protein_seq)
+         masked_peptide = self.mask_sequence(peptide_seq)
+         complex_seq = masked_protein + masked_peptide
+
+         complex_input = self.tokenizer(
+             complex_seq,
+             return_tensors="pt",
+             padding="max_length",
+             max_length=1024,
+             truncation=True,
+             add_special_tokens=False
+         )
+
+         input_ids = complex_input["input_ids"].squeeze()
+         attention_mask = complex_input["attention_mask"].squeeze()
+
+         # Tokenize the unmasked complex as labels; only masked positions keep a label
+         label_seq = protein_seq + peptide_seq
+         labels = self.tokenizer(
+             label_seq,
+             return_tensors="pt",
+             padding="max_length",
+             max_length=1024,
+             truncation=True,
+             add_special_tokens=False
+         )["input_ids"].squeeze()
+
+         labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
+
+         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+
+ # Callback to increase the mask percentage after each epoch
+ class DynamicMaskingCallback(TrainerCallback):
+     def __init__(self, dataset, increment=0.10):
+         self.dataset = dataset
+         self.increment = increment
+
+     def on_epoch_end(self, args, state, control, **kwargs):
+         self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
+         print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100:.0f}%")
+
+ # Load the dataset
+ file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
+ data = pd.read_csv(file_path, delimiter='\t')
+
+ # Split the data by cluster, starting with cluster 0 in the test set
+ test_clusters = [0]
+ remaining_clusters = list(data[data['Cluster'] != 0]['Cluster'].unique())
+ random.shuffle(remaining_clusters)  # visit the remaining clusters in random order
+
+ # Fraction of the dataset that belongs to cluster 0
+ cluster_0_size = (data['Cluster'] == 0).mean()
+
+ # Add clusters until the test split reaches approximately 20% of the dataset
+ test_size = cluster_0_size
+ for cluster in remaining_clusters:
+     cluster_size = (data['Cluster'] == cluster).mean()
+     if test_size + cluster_size > 0.20:
+         break
+     test_clusters.append(cluster)
+     test_size += cluster_size
+
+ # Create test and train splits from the selected clusters
+ test_data = data[data['Cluster'].isin(test_clusters)]
+ train_data = data[~data['Cluster'].isin(test_clusters)]
+
+ proteins_train = train_data["Protein1"].tolist()
+ peptides_train = train_data["Protein2"].tolist()
+ proteins_test = test_data["Protein1"].tolist()
+ peptides_test = test_data["Protein2"].tolist()
+
+ # Load tokenizer and model
+ model_name = "esm2_t33_650M_UR50D"
+ tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)
+
+ # Load model configuration; the dropout overrides below are left commented out
+ config = EsmConfig.from_pretrained("facebook/" + model_name)
+ # config.hidden_dropout_prob = 0.1  # Adjust hidden layer dropout
+ # config.attention_probs_dropout_prob = 0.1  # Adjust attention dropout
+ model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)
+
+ # Generate a timestamp for the output directory
+ current_time = datetime.datetime.now()
+ timestamp = current_time.strftime("%Y%m%d_%H%M%S")
+ output_dir = f'./interact_output_{timestamp}/'
+
+ # Calculate the total number of training steps
+ num_train_epochs = 4
+ per_device_train_batch_size = 8
+ gradient_accumulation_steps = 4
+ total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs
+
+ # Training arguments with cosine learning rate scheduler and gradient clipping
+ training_args = TrainingArguments(
+     output_dir=output_dir,
+     num_train_epochs=num_train_epochs,
+     per_device_train_batch_size=per_device_train_batch_size,
+     per_device_eval_batch_size=8,
+     warmup_steps=10,
+     logging_dir='./logs',
+     logging_steps=10,
+     evaluation_strategy="epoch",
+     load_best_model_at_end=True,
+     save_strategy='epoch',
+     metric_for_best_model='eval_loss',
+     save_total_limit=3,
+     gradient_accumulation_steps=gradient_accumulation_steps,
+     lr_scheduler_type='cosine',
+     max_steps=total_steps,
+     gradient_checkpointing=True,  # Enable gradient checkpointing to save memory
+     max_grad_norm=1.0  # Gradient clipping
+ )
+
+ # Optimizer with weight decay for regularization
+ optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)
+
+ # Instantiate the ProteinDataset for training and testing
+ train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
+ test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)
+
+ # Initialize DynamicMaskingCallback
+ dynamic_masking_callback = DynamicMaskingCallback(train_dataset)
+
+ # Trainer with the dynamic masking callback (gradient clipping is handled via max_grad_norm)
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     eval_dataset=test_dataset,
+     optimizers=(optimizer, None),
+     callbacks=[dynamic_masking_callback]
+ )
+
+ # Start training
+ trainer.train()
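
For reference, a minimal inference sketch (not part of this commit) showing how the fine-tuned masked-LM could be used to fill in masked residues of a protein–peptide complex. The checkpoint path is the one recorded in trainer_state.json below; the example sequences are placeholders.

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# Hypothetical local path; this is the best checkpoint recorded in trainer_state.json
checkpoint_dir = "./interact_output_20231214_183743/checkpoint-912"
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModelForMaskedLM.from_pretrained(checkpoint_dir)
model.eval()

protein = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"    # placeholder protein sequence
peptide = "AWKLFKKI"                             # placeholder peptide sequence
masked_peptide = tokenizer.mask_token * len(peptide)  # mask every peptide residue

# Mirror training: no special tokens, protein and peptide simply concatenated
inputs = tokenizer(protein + masked_peptide, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    logits = model(**inputs).logits

# Greedily decode the most likely residue at each masked position
mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).squeeze(0)
predicted_ids = logits.squeeze(0)[mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))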
config (6).json ADDED
@@ -0,0 +1,30 @@
+ {
+   "_name_or_path": "facebook/esm2_t33_650M_UR50D",
+   "architectures": [
+     "EsmForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "classifier_dropout": null,
+   "emb_layer_norm_before": false,
+   "esmfold_config": null,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 1280,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "is_folding_model": false,
+   "layer_norm_eps": 1e-05,
+   "mask_token_id": 32,
+   "max_position_embeddings": 1026,
+   "model_type": "esm",
+   "num_attention_heads": 20,
+   "num_hidden_layers": 33,
+   "pad_token_id": 1,
+   "position_embedding_type": "rotary",
+   "token_dropout": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.35.2",
+   "use_cache": true,
+   "vocab_list": null,
+   "vocab_size": 33
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3199082cbc493bc37c468be23a85815629423af037d1c4b502d772b3bdb5c62c
+ size 2609498088
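
The three lines above are a Git LFS pointer, not the 2.6 GB weights file itself. A sketch of fetching the actual files with huggingface_hub (the repo id is a placeholder):

from huggingface_hub import snapshot_download

# Placeholder repo id; substitute the repository this commit belongs to
local_dir = snapshot_download(repo_id="AmelieSchreiber/<repo-name>")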
optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60b8758a162379c60e6fccbef2e95cd8fbb1bb183161199bf9406f1980de72d5
+ size 5208792737
rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e09137b0c159ffe21c6157d982dbd6a08be3216152087044c92f27f2ce2e7c1b
+ size 14511
scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fb9eb38fd1861c88d0b0104df285795ad50aee4343b358077d4290d5f3f33316
+ size 563
trainer_state.json ADDED
@@ -0,0 +1,597 @@
+ {
+   "best_metric": 1.1649552583694458,
+   "best_model_checkpoint": "./interact_output_20231214_183743/checkpoint-912",
+   "epoch": 3.991247264770241,
+   "eval_steps": 500,
+   "global_step": 912,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 0.04,
+       "learning_rate": 0.0007984276816171436,
+       "loss": 1.8072,
+       "step": 10
+     },
+     {
+       "epoch": 0.09,
+       "learning_rate": 0.0007981855684763583,
+       "loss": 1.9931,
+       "step": 20
+     },
+     {
+       "epoch": 0.13,
+       "learning_rate": 0.0007974595227250475,
+       "loss": 2.0219,
+       "step": 30
+     },
+     {
+       "epoch": 0.18,
+       "learning_rate": 0.0007962504250201388,
+       "loss": 1.9413,
+       "step": 40
+     },
+     {
+       "epoch": 0.22,
+       "learning_rate": 0.000794559741936249,
+       "loss": 1.8684,
+       "step": 50
+     },
+     {
+       "epoch": 0.26,
+       "learning_rate": 0.0007923895241868038,
+       "loss": 1.7867,
+       "step": 60
+     },
+     {
+       "epoch": 0.31,
+       "learning_rate": 0.0007897424041366252,
+       "loss": 1.8186,
+       "step": 70
+     },
+     {
+       "epoch": 0.35,
+       "learning_rate": 0.0007866215926090057,
+       "loss": 1.7737,
+       "step": 80
+     },
+     {
+       "epoch": 0.39,
+       "learning_rate": 0.0007830308749911415,
+       "loss": 1.6727,
+       "step": 90
+     },
+     {
+       "epoch": 0.44,
+       "learning_rate": 0.0007789746066426482,
+       "loss": 1.6249,
+       "step": 100
+     },
+     {
+       "epoch": 0.48,
+       "learning_rate": 0.0007744577076127291,
+       "loss": 1.7025,
+       "step": 110
+     },
+     {
+       "epoch": 0.53,
+       "learning_rate": 0.0007694856566724036,
+       "loss": 1.6132,
+       "step": 120
+     },
+     {
+       "epoch": 0.57,
+       "learning_rate": 0.0007640644846690332,
+       "loss": 1.63,
+       "step": 130
+     },
+     {
+       "epoch": 0.61,
+       "learning_rate": 0.0007582007672112082,
+       "loss": 1.6888,
+       "step": 140
+     },
+     {
+       "epoch": 0.66,
+       "learning_rate": 0.0007519016166928652,
+       "loss": 1.6102,
+       "step": 150
+     },
+     {
+       "epoch": 0.7,
+       "learning_rate": 0.0007451746736663118,
+       "loss": 1.5319,
+       "step": 160
+     },
+     {
+       "epoch": 0.74,
+       "learning_rate": 0.000738028097574621,
+       "loss": 1.5352,
+       "step": 170
+     },
+     {
+       "epoch": 0.79,
+       "learning_rate": 0.000730470556854638,
+       "loss": 1.4991,
+       "step": 180
+     },
+     {
+       "epoch": 0.83,
+       "learning_rate": 0.0007225112184226035,
+       "loss": 1.495,
+       "step": 190
+     },
+     {
+       "epoch": 0.88,
+       "learning_rate": 0.0007141597365551446,
+       "loss": 1.4296,
+       "step": 200
+     },
+     {
+       "epoch": 0.92,
+       "learning_rate": 0.0007054262411791251,
+       "loss": 1.4373,
+       "step": 210
+     },
+     {
+       "epoch": 0.96,
+       "learning_rate": 0.0006963213255845531,
+       "loss": 1.4589,
+       "step": 220
+     },
+     {
+       "epoch": 1.0,
+       "eval_loss": 1.5730081796646118,
+       "eval_runtime": 181.0133,
+       "eval_samples_per_second": 14.883,
+       "eval_steps_per_second": 1.862,
+       "step": 228
+     },
+     {
+       "epoch": 1.01,
+       "learning_rate": 0.0006868560335754548,
+       "loss": 1.4361,
+       "step": 230
+     },
+     {
+       "epoch": 1.05,
+       "learning_rate": 0.000677041846074296,
+       "loss": 1.3813,
+       "step": 240
+     },
+     {
+       "epoch": 1.09,
+       "learning_rate": 0.000666890667196201,
+       "loss": 1.3511,
+       "step": 250
+     },
+     {
+       "epoch": 1.14,
+       "learning_rate": 0.0006564148098098617,
+       "loss": 1.455,
+       "step": 260
+     },
+     {
+       "epoch": 1.18,
+       "learning_rate": 0.0006456269806026464,
+       "loss": 1.4276,
+       "step": 270
+     },
+     {
+       "epoch": 1.23,
+       "learning_rate": 0.00063454026466803,
+       "loss": 1.2906,
+       "step": 280
+     },
+     {
+       "epoch": 1.27,
+       "learning_rate": 0.0006231681096340324,
+       "loss": 1.3605,
+       "step": 290
+     },
+     {
+       "epoch": 1.31,
+       "learning_rate": 0.0006115243093519255,
+       "loss": 1.3765,
+       "step": 300
+     },
+     {
+       "epoch": 1.36,
+       "learning_rate": 0.0005996229871649842,
+       "loss": 1.3846,
+       "step": 310
+     },
+     {
+       "epoch": 1.4,
+       "learning_rate": 0.0005874785787775835,
+       "loss": 1.3476,
+       "step": 320
+     },
+     {
+       "epoch": 1.44,
+       "learning_rate": 0.0005751058147454162,
+       "loss": 1.3307,
+       "step": 330
+     },
+     {
+       "epoch": 1.49,
+       "learning_rate": 0.0005625197026080706,
+       "loss": 1.3481,
+       "step": 340
+     },
+     {
+       "epoch": 1.53,
+       "learning_rate": 0.00054973550868564,
+       "loss": 1.2677,
+       "step": 350
+     },
+     {
+       "epoch": 1.58,
+       "learning_rate": 0.0005367687395614475,
+       "loss": 1.2801,
+       "step": 360
+     },
+     {
+       "epoch": 1.62,
+       "learning_rate": 0.0005236351232733387,
+       "loss": 1.2434,
+       "step": 370
+     },
+     {
+       "epoch": 1.66,
+       "learning_rate": 0.0005103505902363665,
+       "loss": 1.2472,
+       "step": 380
+     },
+     {
+       "epoch": 1.71,
+       "learning_rate": 0.0004969312539199984,
+       "loss": 1.1805,
+       "step": 390
+     },
+     {
+       "epoch": 1.75,
+       "learning_rate": 0.0004833933913032899,
+       "loss": 1.2795,
+       "step": 400
+     },
+     {
+       "epoch": 1.79,
+       "learning_rate": 0.0004697534231317295,
+       "loss": 1.1841,
+       "step": 410
+     },
+     {
+       "epoch": 1.84,
+       "learning_rate": 0.00045602789399970073,
+       "loss": 1.2189,
+       "step": 420
+     },
+     {
+       "epoch": 1.88,
+       "learning_rate": 0.0004422334522827224,
+       "loss": 1.2124,
+       "step": 430
+     },
+     {
+       "epoch": 1.93,
+       "learning_rate": 0.00042838682994380845,
+       "loss": 1.1371,
+       "step": 440
+     },
+     {
+       "epoch": 1.97,
+       "learning_rate": 0.00041450482223843874,
+       "loss": 1.1254,
+       "step": 450
+     },
+     {
+       "epoch": 2.0,
+       "eval_loss": 1.318668007850647,
+       "eval_runtime": 181.264,
+       "eval_samples_per_second": 14.862,
+       "eval_steps_per_second": 1.859,
+       "step": 457
+     },
+     {
+       "epoch": 2.01,
+       "learning_rate": 0.0004006042673427602,
+       "loss": 1.2324,
+       "step": 460
+     },
+     {
+       "epoch": 2.06,
+       "learning_rate": 0.0003867020259297277,
+       "loss": 1.2353,
+       "step": 470
+     },
+     {
+       "epoch": 2.1,
+       "learning_rate": 0.00037281496071795675,
+       "loss": 1.2029,
+       "step": 480
+     },
+     {
+       "epoch": 2.14,
+       "learning_rate": 0.0003589599160180951,
+       "loss": 1.1946,
+       "step": 490
+     },
+     {
+       "epoch": 2.19,
+       "learning_rate": 0.0003451536973015218,
+       "loss": 1.2571,
+       "step": 500
+     },
+     {
+       "epoch": 2.23,
+       "learning_rate": 0.0003314130508161583,
+       "loss": 1.1964,
+       "step": 510
+     },
+     {
+       "epoch": 2.28,
+       "learning_rate": 0.0003177546432741117,
+       "loss": 1.2171,
+       "step": 520
+     },
+     {
+       "epoch": 2.32,
+       "learning_rate": 0.00030419504163579317,
+       "loss": 1.1815,
+       "step": 530
+     },
+     {
+       "epoch": 2.36,
+       "learning_rate": 0.00029075069301502925,
+       "loss": 1.1589,
+       "step": 540
+     },
+     {
+       "epoch": 2.41,
+       "learning_rate": 0.000277437904729541,
+       "loss": 1.1154,
+       "step": 550
+     },
+     {
+       "epoch": 2.45,
+       "learning_rate": 0.0002642728245209895,
+       "loss": 1.1195,
+       "step": 560
+     },
+     {
+       "epoch": 2.49,
+       "learning_rate": 0.0002512714209685778,
+       "loss": 1.1485,
+       "step": 570
+     },
+     {
+       "epoch": 2.54,
+       "learning_rate": 0.00023844946411996905,
+       "loss": 1.1151,
+       "step": 580
+     },
+     {
+       "epoch": 2.58,
+       "learning_rate": 0.0002258225063630134,
+       "loss": 1.1342,
+       "step": 590
+     },
+     {
+       "epoch": 2.63,
+       "learning_rate": 0.00021340586356148388,
+       "loss": 1.1106,
+       "step": 600
+     },
+     {
+       "epoch": 2.67,
+       "learning_rate": 0.0002012145964777057,
+       "loss": 1.0693,
+       "step": 610
+     },
+     {
+       "epoch": 2.71,
+       "learning_rate": 0.00018926349250461,
+       "loss": 1.1118,
+       "step": 620
+     },
+     {
+       "epoch": 2.76,
+       "learning_rate": 0.00017756704772937113,
+       "loss": 1.097,
+       "step": 630
+     },
+     {
+       "epoch": 2.8,
+       "learning_rate": 0.00016613944935038317,
+       "loss": 1.0072,
+       "step": 640
+     },
+     {
+       "epoch": 2.84,
+       "learning_rate": 0.000154994558468902,
+       "loss": 1.0244,
+       "step": 650
+     },
+     {
+       "epoch": 2.89,
+       "learning_rate": 0.0001441458932762289,
+       "loss": 1.0308,
+       "step": 660
+     },
+     {
+       "epoch": 2.93,
+       "learning_rate": 0.00013360661265682426,
+       "loss": 0.9882,
+       "step": 670
+     },
+     {
+       "epoch": 2.98,
+       "learning_rate": 0.00012338950022724405,
+       "loss": 0.9938,
+       "step": 680
+     },
+     {
+       "epoch": 3.0,
+       "eval_loss": 1.1857038736343384,
+       "eval_runtime": 181.1024,
+       "eval_samples_per_second": 14.876,
+       "eval_steps_per_second": 1.861,
+       "step": 685
+     },
+     {
+       "epoch": 3.02,
+       "learning_rate": 0.00011350694883025702,
+       "loss": 1.0906,
+       "step": 690
+     },
+     {
+       "epoch": 3.06,
+       "learning_rate": 0.00010397094550294988,
+       "loss": 1.1792,
+       "step": 700
+     },
+     {
+       "epoch": 3.11,
+       "learning_rate": 9.4793056937056e-05,
+       "loss": 1.0951,
+       "step": 710
+     },
+     {
+       "epoch": 3.15,
+       "learning_rate": 8.598441544914002e-05,
+       "loss": 1.1168,
+       "step": 720
+     },
+     {
+       "epoch": 3.19,
+       "learning_rate": 7.755570547765905e-05,
+       "loss": 1.0971,
+       "step": 730
+     },
+     {
+       "epoch": 3.24,
+       "learning_rate": 6.951715062327716e-05,
+       "loss": 1.0359,
+       "step": 740
+     },
+     {
+       "epoch": 3.28,
+       "learning_rate": 6.187850124815228e-05,
+       "loss": 1.077,
+       "step": 750
+     },
+     {
+       "epoch": 3.33,
+       "learning_rate": 5.4649022649238026e-05,
+       "loss": 1.0996,
+       "step": 760
+     },
+     {
+       "epoch": 3.37,
+       "learning_rate": 4.783748381994562e-05,
+       "loss": 1.0043,
+       "step": 770
+     },
+     {
+       "epoch": 3.41,
+       "learning_rate": 4.145214681379591e-05,
+       "loss": 1.1422,
+       "step": 780
+     },
+     {
+       "epoch": 3.46,
+       "learning_rate": 3.550075672296503e-05,
+       "loss": 1.1366,
+       "step": 790
+     },
+     {
+       "epoch": 3.5,
+       "learning_rate": 2.9990532283877747e-05,
+       "loss": 1.0587,
+       "step": 800
+     },
+     {
+       "epoch": 3.54,
+       "learning_rate": 2.492815712124332e-05,
+       "loss": 1.1301,
+       "step": 810
+     },
+     {
+       "epoch": 3.59,
+       "learning_rate": 2.0319771641155883e-05,
+       "loss": 1.0567,
+       "step": 820
+     },
+     {
+       "epoch": 3.63,
+       "learning_rate": 1.617096558309071e-05,
+       "loss": 1.1119,
+       "step": 830
+     },
+     {
+       "epoch": 3.68,
+       "learning_rate": 1.2486771239831942e-05,
+       "loss": 1.1186,
+       "step": 840
+     },
+     {
+       "epoch": 3.72,
+       "learning_rate": 9.271657353555046e-06,
+       "loss": 1.0765,
+       "step": 850
+     },
+     {
+       "epoch": 3.76,
+       "learning_rate": 6.529523695467422e-06,
+       "loss": 1.0678,
+       "step": 860
+     },
+     {
+       "epoch": 3.81,
+       "learning_rate": 4.263696335582372e-06,
+       "loss": 1.1022,
+       "step": 870
+     },
+     {
+       "epoch": 3.85,
+       "learning_rate": 2.476923608363819e-06,
+       "loss": 1.0498,
+       "step": 880
+     },
+     {
+       "epoch": 3.89,
+       "learning_rate": 1.1713727791349433e-06,
+       "loss": 1.0907,
+       "step": 890
+     },
+     {
+       "epoch": 3.94,
+       "learning_rate": 3.4862741529444126e-07,
+       "loss": 1.0615,
+       "step": 900
+     },
+     {
+       "epoch": 3.98,
+       "learning_rate": 9.685465529235211e-09,
+       "loss": 1.0461,
+       "step": 910
+     },
+     {
+       "epoch": 3.99,
+       "eval_loss": 1.1649552583694458,
+       "eval_runtime": 181.1967,
+       "eval_samples_per_second": 14.868,
+       "eval_steps_per_second": 1.86,
+       "step": 912
+     }
+   ],
+   "logging_steps": 10,
+   "max_steps": 912,
+   "num_train_epochs": 4,
+   "save_steps": 500,
+   "total_flos": 1.1665671520864666e+17,
+   "trial_name": null,
+   "trial_params": null
+ }
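
The log_history above can be plotted directly to inspect convergence; a minimal sketch, assuming matplotlib is available:

import json
import matplotlib.pyplot as plt

with open("trainer_state.json") as f:
    state = json.load(f)

# Training entries carry "loss"; evaluation entries carry "eval_loss"
train_logs = [e for e in state["log_history"] if "loss" in e]
eval_logs = [e for e in state["log_history"] if "eval_loss" in e]

plt.plot([e["step"] for e in train_logs], [e["loss"] for e in train_logs], label="train loss")
plt.plot([e["step"] for e in eval_logs], [e["eval_loss"] for e in eval_logs], "o-", label="eval loss")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()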
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cdaa6b6ca3d18933180a2b04e43e87f74d8e989f6fe9c9b5d31fdeba7acadaa1
+ size 4091
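
Together, optimizer.pt, scheduler.pt, rng_state.pth, trainer_state.json, and training_args.bin form a full Trainer checkpoint state, so training can be resumed rather than restarted. A sketch, assuming the files sit in a standard checkpoint directory (the path below is the one recorded in trainer_state.json) and that the model, datasets, and arguments are rebuilt exactly as in clustered_ppi_train.py:

# `trainer` is the Trainer instance constructed in clustered_ppi_train.py
trainer.train(resume_from_checkpoint="./interact_output_20231214_183743/checkpoint-912")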