finnstrom3693 committed
Commit 54810d7
1 Parent(s): c42ab1b

add build call

Files changed (1)
  1. modeling4.py +9 -12
modeling4.py CHANGED
@@ -1,4 +1,3 @@
-# @title Model Architecture
 import tensorflow as tf
 from tensorflow.keras import layers, activations, initializers
 
@@ -28,13 +27,19 @@ class MiniSunModel(tf.keras.Model):
         self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
         self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
 
-        # Transformer decoder blocks
-        self.decoder_blocks = [self._build_decoder_block() for _ in range(config.num_hidden_layers)]
+        # Initialize an empty list for decoder blocks
+        self.decoder_blocks = []
 
         # Final normalization and head
         self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
         self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())
 
+    def build(self, input_shape):
+        # Create transformer decoder blocks based on the model configuration
+        self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
+        # Call the superclass's build method
+        super(MiniSunModel, self).build(input_shape)
+
     def _build_decoder_block(self):
         # Decoder block consisting of multi-head attention and feed-forward layers
         return [
@@ -91,7 +96,6 @@ class MiniSunModel(tf.keras.Model):
 
     def compute_loss(self, labels, logits):
         """Computes the loss between labels and logits."""
-        # Ensure labels and logits are not None
         if labels is None or logits is None:
             raise ValueError("Labels and logits cannot be None.")
         return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
@@ -117,7 +121,6 @@ class MiniSunModel(tf.keras.Model):
 
         return {m.name: m.result() for m in self.metrics}
 
-
 def create_model(config):
     model = MiniSunModel(config)
 
@@ -137,7 +140,6 @@ def cosine_annealing_with_warmup(step, config):
     if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)
     else:
-        # Calculate the cosine decay
        cos_step = step - warmup_steps
        total_cos_steps = config.total_steps - warmup_steps
        return 0.5 * config.learning_rate * (1 + tf.cos(tf.constant(np.pi) * cos_step / total_cos_steps))
@@ -146,21 +148,16 @@ def cosine_annealing_with_restarts(step, config, restart_period, cycle_num):
     """Learning rate schedule with warm-up and cosine annealing with restarts."""
     warmup_steps = int(config.total_steps * config.warmup_steps)
 
-    # Determine the current cycle based on step and restart_period
     current_cycle = step // restart_period
-
-    # Calculate the effective step within the current cycle
     effective_step = step % restart_period
 
     if effective_step < warmup_steps:
         return config.learning_rate * (effective_step / warmup_steps)
     else:
-        # Calculate the cosine decay within the current cycle
         cos_step = effective_step - warmup_steps
         total_cos_steps = restart_period - warmup_steps
         return 0.5 * config.learning_rate * (1 + tf.cos(tf.constant(np.pi) * cos_step / total_cos_steps))
 
-
 # Configuration
 config = MiniSunConfig()
 
@@ -169,4 +166,4 @@ model = create_model(config)
 
 # Create a LearningRateScheduler callback
 lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_warmup(step, config))
-#lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config, restart_period=1000, cycle_num=1))
+# lr_scheduler_with_restarts = tf.keras.callbacks.LearningRateScheduler(lambda step: cosine_annealing_with_restarts(step, config, restart_period=1000, cycle_num=1))
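For context on the commit message: moving the decoder-block construction out of __init__ and into Keras's build() hook means the blocks are created lazily, once the framework first sees an input shape. A minimal usage sketch of the resulting behavior follows; it is hypothetical in that it assumes MiniSunModel's call() (not shown in this diff) accepts a batch of token ids of shape (batch, seq_len).

import tensorflow as tf

# Hypothetical usage sketch for the deferred build() added in this commit.
config = MiniSunConfig()                  # defined earlier in modeling4.py
model = MiniSunModel(config)

print(len(model.decoder_blocks))          # 0 -- nothing constructed yet

dummy_ids = tf.zeros((1, 16), dtype=tf.int32)   # assumed (batch, seq_len) int input
_ = model(dummy_ids)                      # first call triggers build(input_shape)

print(len(model.decoder_blocks))          # now equals config.num_hidden_layers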
 
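The scheduler wiring at the bottom of the file is easier to follow with a toy trace of cosine_annealing_with_warmup. The sketch below is self-contained and uses illustrative values; it assumes config.warmup_steps is a fraction of config.total_steps, which the int(config.total_steps * config.warmup_steps) line implies.

import numpy as np

# Toy trace of the warm-up + cosine-decay schedule with assumed values.
learning_rate, total_steps, warmup_frac = 3e-4, 1000, 0.1
warmup_steps = int(total_steps * warmup_frac)   # 100

def lr_at(step):
    if step < warmup_steps:
        return learning_rate * step / warmup_steps          # linear warm-up
    cos_step = step - warmup_steps
    total_cos_steps = total_steps - warmup_steps
    return 0.5 * learning_rate * (1 + np.cos(np.pi * cos_step / total_cos_steps))

for step in (0, 50, 100, 550, 1000):
    print(step, lr_at(step))
# 0 -> 0.0, 50 -> lr/2 (ramp), 100 -> lr (peak), 550 -> ~lr/2, 1000 -> ~0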