carlesoctav committed on
Commit
4127b6d
1 Parent(s): e9bf3eb

Update train.py

Files changed (1)
  1. train.py +199 -0
train.py CHANGED
@@ -0,0 +1,199 @@
+
+ import tensorflow as tf
+ from transformers import TFAutoModel
+ from datasets import load_dataset, concatenate_datasets
+ from datetime import datetime
+ from pyprojroot.here import here
+
+
+
+ class MeanPoolingLayer(tf.keras.layers.Layer):
+     """Masked mean pooling over token embeddings. Defined for reuse, although
+     create_model() below takes the CLS token instead."""
+
+     def __init__(self):
+         super().__init__()
+
+     def call(self, inputs):
+         token_embeddings, attention_mask = inputs
+         # Broadcast the attention mask across the hidden dimension.
+         input_mask_expanded = tf.cast(
+             tf.broadcast_to(tf.expand_dims(attention_mask, -1), tf.shape(token_embeddings)),
+             tf.float32,
+         )
+         # Sum the unmasked token embeddings and divide by the token count,
+         # clipping the denominator to avoid division by zero.
+         embeddings = tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1) / tf.clip_by_value(
+             tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max
+         )
+         return embeddings
+
+     def get_config(self):
+         return super().get_config()
+
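+ # A minimal usage sketch for MeanPoolingLayer (commented out so the script is
+ # unchanged; the tensors here are hypothetical):
+ # hidden = tf.random.normal((2, 256, 384))     # (batch, seq_len, hidden)
+ # mask = tf.ones((2, 256), dtype=tf.int32)     # attention mask
+ # pooled = MeanPoolingLayer()([hidden, mask])  # -> shape (2, 384)
+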
+
+ def create_model():
+     # Two-tower setup: the same student encoder embeds both the English and
+     # the Indonesian side of each sentence pair.
+     base_student_model = TFAutoModel.from_pretrained(
+         "nreimers/mMiniLMv2-L6-H384-distilled-from-XLMR-Large", from_pt=True
+     )
+     input_ids_en = tf.keras.layers.Input(shape=(256,), name="input_ids_en", dtype=tf.int32)
+     attention_mask_en = tf.keras.layers.Input(shape=(256,), name="attention_mask_en", dtype=tf.int32)
+     input_ids_id = tf.keras.layers.Input(shape=(256,), name="input_ids_id", dtype=tf.int32)
+     attention_mask_id = tf.keras.layers.Input(shape=(256,), name="attention_mask_id", dtype=tf.int32)
+
+     # Use the CLS token (first position) of the last hidden state as the
+     # sentence embedding for each language.
+     output_en = base_student_model.roberta(input_ids_en, attention_mask=attention_mask_en).last_hidden_state[:, 0, :]
+     output_id = base_student_model.roberta(input_ids_id, attention_mask=attention_mask_id).last_hidden_state[:, 0, :]
+
+     student_model = tf.keras.Model(
+         inputs=[input_ids_en, attention_mask_en, input_ids_id, attention_mask_id],
+         outputs=[output_en, output_id],
+     )
+     return student_model
+
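+ # Quick shape check for the two-tower model (a sketch, commented out; the
+ # dummy inputs below are hypothetical):
+ # model = create_model()
+ # dummy = tf.ones((1, 256), dtype=tf.int32)
+ # emb_en, emb_id = model([dummy, dummy, dummy, dummy])
+ # print(emb_en.shape)  # expected: (1, 384), the encoder's hidden size
+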
+ class SentenceTranslationMetric(tf.keras.callbacks.Callback):
+     """Bidirectional translation-retrieval accuracy on the module-level
+     val_dataset: a pair counts as correct when its translation is the
+     nearest neighbour by cosine similarity."""
+
+     def on_epoch_end(self, epoch, logs=None):
+         embeddings_en, embeddings_id = self.model.predict(val_dataset, verbose=1)
+         # L2-normalize so the dot product below equals cosine similarity.
+         embeddings_en = tf.math.l2_normalize(embeddings_en, axis=1)
+         embeddings_id = tf.math.l2_normalize(embeddings_id, axis=1)
+         similarity_matrix = tf.matmul(embeddings_en, embeddings_id, transpose_b=True)
+
+         # English -> Indonesian retrieval: row i should peak at column i.
+         correct_en_id = 0
+         for i in range(similarity_matrix.shape[0]):
+             if tf.math.argmax(similarity_matrix[i]) == i:
+                 correct_en_id += 1
+
+         # Indonesian -> English retrieval on the transposed matrix.
+         similarity_matrix_T = tf.transpose(similarity_matrix)
+         correct_id_en = 0
+         for i in range(similarity_matrix_T.shape[0]):
+             if tf.math.argmax(similarity_matrix_T[i]) == i:
+                 correct_id_en += 1
+
+         acc_en_id = correct_en_id / similarity_matrix.shape[0]
+         acc_id_en = correct_id_en / similarity_matrix_T.shape[0]
+         avg_acc = (acc_en_id + acc_id_en) / 2
+         print(f"translation accuracy from English to Indonesian = {acc_en_id}")
+         print(f"translation accuracy from Indonesian to English = {acc_id_en}")
+         print(f"average translation accuracy = {avg_acc}")
+
+         # Expose the metrics so ModelCheckpoint/EarlyStopping can monitor them.
+         logs["val_acc_en_id"] = acc_en_id
+         logs["val_acc_id_en"] = acc_id_en
+         logs["val_avg_acc"] = avg_acc
+
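+ # A vectorized equivalent of the accuracy loops above (a sketch, commented out):
+ # n = tf.shape(similarity_matrix)[0]
+ # hits = tf.equal(tf.argmax(similarity_matrix, axis=1, output_type=tf.int32), tf.range(n))
+ # acc_en_id = tf.reduce_mean(tf.cast(hits, tf.float32))
+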
+ class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
+     """The Transformer warmup schedule from "Attention Is All You Need":
+     lr(step) = rsqrt(d_model) * min(rsqrt(step), step * warmup_steps**-1.5)."""
+
+     def __init__(self, d_model, warmup_steps=100000):
+         super().__init__()
+         self.d_model = tf.cast(d_model, tf.float32)
+         self.warmup_steps = warmup_steps
+
+     def __call__(self, step):
+         step = tf.cast(step, dtype=tf.float32)
+         # Linear warmup for warmup_steps, then inverse-square-root decay.
+         arg1 = tf.math.rsqrt(step)
+         arg2 = step * (self.warmup_steps ** -1.5)
+         return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
+
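+ # Sanity check for the schedule (a sketch, commented out): the learning rate
+ # peaks at step == warmup_steps, where both branches coincide:
+ # lr = CustomSchedule(384)
+ # print(lr(tf.constant(100000.0)))  # ~= 1/sqrt(384 * 100000) ~= 1.6e-4
+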
+ if __name__ == "__main__":
+     dataset = load_dataset("carlesoctav/en-id-parallel-sentences-embedding")
+
+     # Merge every split of the dataset into a single training pool.
+     dataset_1 = concatenate_datasets([dataset[split] for split in dataset])
+
+
+     batch_size = 384
+     dataset = dataset_1.train_test_split(test_size=0.01, shuffle=True)
+     train_dataset = dataset["train"]
+     val_dataset = dataset["test"]
+     print(val_dataset.shape)
+
+     # Export to tf.data, then unbatch and rebatch with drop_remainder=True so
+     # every batch has a static shape (required for TPU execution).
+     train_dataset = train_dataset.to_tf_dataset(
+         columns=["input_ids_en", "attention_mask_en", "input_ids_id", "attention_mask_id"],
+         label_cols="target_embedding",
+         batch_size=batch_size,
+     ).unbatch()
+
+     val_dataset = val_dataset.to_tf_dataset(
+         columns=["input_ids_en", "attention_mask_en", "input_ids_id", "attention_mask_id"],
+         label_cols="target_embedding",
+         batch_size=batch_size,
+     ).unbatch()
+
+     # Check the feature spec.
+     print(train_dataset.element_spec)
+     print(val_dataset.element_spec)
+
+     train_dataset = train_dataset.batch(batch_size, drop_remainder=True).cache()
+     val_dataset = val_dataset.batch(batch_size, drop_remainder=True).cache()
+
+
+     # Warmup schedule sized to the model's hidden dimension (384), with the
+     # Adam hyperparameters from "Attention Is All You Need".
+     learning_rate = CustomSchedule(384)
+     optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
+
+     loss = tf.keras.losses.MeanSquaredError()
+
+     date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+     output_path = here(f"disk/model/{date_time}/multiqa-mpnet-dot-v1.h5")
+
+     # Keep only the best weights by validation retrieval accuracy.
+     model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
+         filepath=output_path,
+         save_weights_only=True,
+         monitor="val_avg_acc",
+         mode="max",
+         verbose=1,
+         save_best_only=True,
+         initial_value_threshold=0.1,
+     )
+
+     # With epochs=20 in fit() below, start_from_epoch=25 disables early stopping.
+     early_stopping = tf.keras.callbacks.EarlyStopping(
+         monitor="val_avg_acc",
+         mode="max",
+         restore_best_weights=False,
+         patience=2,
+         verbose=1,
+         start_from_epoch=25,
+     )
+
+
+     # tensor_board = tf.keras.callbacks.TensorBoard(
+     #     log_dir="gs://dicoding-capstone/output/logs/" + date_time
+     # )
+
+     csv_logger = tf.keras.callbacks.CSVLogger(
+         filename=here(f"disk/performance_logs/log-{date_time}.csv"),
+         separator=",",
+         append=False,
+     )
+
+     # Defined but not added to the callbacks list below; ReduceLROnPlateau
+     # cannot adjust a LearningRateSchedule-driven optimizer anyway.
+     reduce_rl = tf.keras.callbacks.ReduceLROnPlateau(
+         monitor="val_avg_acc",
+         mode="max",
+         factor=0.1,
+         patience=2,
+         min_lr=1e-6,
+         verbose=1,
+     )
+
+
+     # The metric callback must come first so val_avg_acc is in logs before
+     # ModelCheckpoint and EarlyStopping read it.
+     callbacks = [SentenceTranslationMetric(), model_checkpoint, csv_logger, early_stopping]
+
+
+     # Connect to the local TPU and distribute training across its cores.
+     cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver("local")
+     tf.config.experimental_connect_to_cluster(cluster_resolver)
+     tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+     strategy = tf.distribute.TPUStrategy(cluster_resolver)
+
+     # Build and compile the model under the TPU strategy scope; both outputs
+     # regress onto the teacher's target_embedding with MSE.
+     with strategy.scope():
+         student_model = create_model()
+         student_model.compile(optimizer=optimizer, loss=loss)
+
+     student_model.fit(train_dataset, epochs=20, validation_data=val_dataset, callbacks=callbacks)
+
+     # Also save the final-epoch weights, independent of the best checkpoint.
+     last_epoch_save = here(f"disk/model/last_epoch/{date_time}/multiqa-mpnet-dot-v1.h5")
+     student_model.save_weights(last_epoch_save)
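+
+ # To reuse the weights later (a sketch, commented out): rebuild the same graph
+ # and load either checkpoint; the path below is hypothetical.
+ # model = create_model()
+ # model.load_weights("disk/model/last_epoch/<date_time>/multiqa-mpnet-dot-v1.h5")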