carlesoctav committed on
Commit d3b88f1
1 Parent(s): 4127b6d

Update train.py

Files changed (1)
  1. train.py +20 -28
train.py CHANGED
@@ -1,4 +1,3 @@
-
 import tensorflow as tf
 from transformers import TFXLMRobertaModel, AutoTokenizer, TFAutoModel
 from datasets import load_dataset, concatenate_datasets
@@ -6,8 +5,6 @@ from datetime import datetime
 import logging
 from pyprojroot.here import here
 
-
-
 class mean_pooling_layer(tf.keras.layers.Layer):
     def __init__(self):
         super(mean_pooling_layer, self).__init__()
@@ -39,6 +36,7 @@ def create_model():
     output_id = base_student_model.roberta(input_ids_id, attention_mask=attention_mask_id).last_hidden_state[:,0,:]
 
     student_model = tf.keras.Model(inputs=[input_ids_en, attention_mask_en, input_ids_id, attention_mask_id], outputs=[output_en, output_id])
+    student_model.load_weights("disk/model/2023-05-25_07-52-43/multiqa-Mmini-L6-H384.h5")
     return student_model
 
 class sentence_translation_metric(tf.keras.callbacks.Callback):
@@ -74,21 +72,19 @@ class sentence_translation_metric(tf.keras.callbacks.Callback):
         logs["val_avg_acc"] = avg_acc
 
 
-class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
-    def __init__(self, d_model, warmup_steps=100000):
+class ConstantScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
+    def __init__(self, max_lr, warmup_steps=5000):
         super().__init__()
-
-        self.d_model = d_model
-        self.d_model = tf.cast(self.d_model, tf.float32)
-
+        self.max_lr = tf.cast(max_lr, tf.float32)
         self.warmup_steps = warmup_steps
 
     def __call__(self, step):
-        step = tf.cast(step, dtype=tf.float32)
-        arg1 = tf.math.rsqrt(step)
-        arg2 = step * (self.warmup_steps ** -1.5)
-
-        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
+        step = tf.cast(step, tf.float32)
+        condition = tf.cond(step < self.warmup_steps, lambda: step / self.warmup_steps, lambda: 1.0)
+        return self.max_lr * condition
+
+
+
 
 
 if __name__ == "__main__":
@@ -101,8 +97,8 @@ if __name__ == "__main__":
         dataset_1 = concatenate_datasets([dataset_1, dataset[split]])
 
 
-    batch_size = 384
-    dataset = dataset_1.train_test_split(test_size=0.01, shuffle=True)
+    batch_size = 512
+    dataset = dataset_1.train_test_split(test_size=0.005, shuffle=True)
     train_dataset = dataset["train"]
     val_dataset = dataset["test"]
     print(val_dataset.shape)
@@ -127,7 +123,8 @@ if __name__ == "__main__":
     val_dataset = val_dataset.batch(batch_size, drop_remainder=True).cache()
 
 
-    learning_rate = CustomSchedule(384)
+    learning_rate = ConstantScheduler(1e-3, warmup_steps=10000)
+
 
     optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                          epsilon=1e-9)
@@ -137,7 +134,7 @@ if __name__ == "__main__":
     loss = tf.keras.losses.MeanSquaredError()
 
     date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    output_path = here(f"disk/model/{date_time}/multiqa-mpnet-dot-v1.h5")
+    output_path = here(f"disk/model/{date_time}/multiqa-Mmini-L6-H384.h5")
 
     model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
         filepath = output_path,
@@ -146,16 +143,16 @@ if __name__ == "__main__":
         mode = 'auto',
         verbose = 1,
         save_best_only = True,
-        initial_value_threshold = 0.1
+        initial_value_threshold = 0.5,
     )
 
     early_stopping = tf.keras.callbacks.EarlyStopping(
         monitor = "val_avg_acc",
         mode = 'auto',
         restore_best_weights=False,
-        patience = 2,
+        patience = 4,
         verbose=1,
-        start_from_epoch = 25,
+        start_from_epoch = 5,
     )
 
 
@@ -169,13 +166,7 @@ if __name__ == "__main__":
         append = False
     )
 
-    reduce_rl = tf.keras.callbacks.ReduceLROnPlateau(
-        monitor = "",
-        factor = 0.1,
-        patience = 2,
-        min_lr = 1e-6,
-        verbose = 1
-    )
+
 
 
     callbacks = [sentence_translation_metric(), model_checkpoint, csv_logger, early_stopping]
@@ -192,7 +183,8 @@ if __name__ == "__main__":
 
     student_model.fit(train_dataset, epochs=20, validation_data=val_dataset, callbacks=callbacks)
 
-    last_epoch_save = here(f"disk/model/last_epoch/{date_time}/multiqa-mpnet-dot-v1.h5")
+
+    last_epoch_save = here(f"disk/model/last_epoch/{date_time}.h5")
     student_model.save_weights(last_epoch_save)
 
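The schedule change above swaps the inverse-square-root (transformer-style) CustomSchedule for a linear warmup that then holds the learning rate constant. A minimal, self-contained sketch for sanity-checking the new curve outside the training script — the class body mirrors the diff; the sample steps and print loop are illustrative only:

import tensorflow as tf

class ConstantScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup from 0 to max_lr over warmup_steps, then constant (as in the diff)."""
    def __init__(self, max_lr, warmup_steps=5000):
        super().__init__()
        self.max_lr = tf.cast(max_lr, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Ramp linearly during warmup, then hold the rate at max_lr.
        scale = tf.cond(step < self.warmup_steps,
                        lambda: step / self.warmup_steps,
                        lambda: tf.constant(1.0))
        return self.max_lr * scale

# Sample the schedule with the values used in the commit (max_lr=1e-3, warmup_steps=10000).
schedule = ConstantScheduler(1e-3, warmup_steps=10000)
for step in [0, 2500, 5000, 10000, 50000]:
    print(f"step {step:>6}: lr = {float(schedule(step)):.2e}")
# step      0: lr = 0.00e+00
# step   2500: lr = 2.50e-04
# step   5000: lr = 5.00e-04
# step  10000: lr = 1.00e-03
# step  50000: lr = 1.00e-03

With the committed values, the rate climbs linearly over the first 10,000 steps and stays at 1e-3 afterwards, whereas the removed CustomSchedule decayed as 1/sqrt(step) once warmup ended.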