# Source: uploaded by finnstrom3693 — commit 54810d7 (verified), "add build call" (7.39 kB).
import math

import tensorflow as tf
from tensorflow.keras import layers, activations, initializers
class MiniSunConfig:
    """Hyper-parameters for MiniSunModel and its learning-rate schedule.

    Each constructor argument is stored as an instance attribute of the same
    name, in signature order, so ``MiniSunConfig(**vars(cfg))`` rebuilds an
    equivalent object from its attribute dict.

    Note: ``warmup_steps`` is a *fraction* of ``total_steps`` (default 0.2),
    not a step count; the schedules compute ``int(total_steps * warmup_steps)``.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        max_position_embeddings: int = 1024,
        hidden_size: int = 512,
        num_attention_heads: int = 8,
        intermediate_size: int = 2048,
        num_hidden_layers: int = 8,
        dropout_rate: float = 0.1,
        weight_decay: float = 0.01,
        learning_rate: float = 1e-4,
        total_steps: int = 2500,
        warmup_steps: float = 0.2,
    ) -> None:
        # Assignment order mirrors the signature so the attribute dict
        # (used for serialization) keeps the same key order.
        # -- architecture --
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        # -- regularization / optimizer --
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        # -- learning-rate schedule --
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
@tf.keras.utils.register_keras_serializable()
class MiniSunModel(tf.keras.Model):
    """Decoder-only transformer language model built from a MiniSunConfig.

    ``call`` takes a mapping with ``'input_ids'`` and, optionally,
    ``'attention_mask'`` (non-zero = attend, zero = padding) and returns
    vocabulary logits for every input position.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Token embeddings plus learned absolute position embeddings.
        self.token_embedding = layers.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = layers.Embedding(config.max_position_embeddings, config.hidden_size)
        # Decoder blocks are created in build(); start with an empty list.
        self.decoder_blocks = []
        # Final normalization and projection to vocabulary logits.
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.lm_head = layers.Dense(config.vocab_size, kernel_initializer=initializers.he_normal())

    def build(self, input_shape):
        """Create ``config.num_hidden_layers`` decoder blocks, then build the model."""
        self.decoder_blocks = [self._build_decoder_block() for _ in range(self.config.num_hidden_layers)]
        super().build(input_shape)

    def _build_decoder_block(self):
        """Return one decoder block as ``[mha, norm, ffn1, ffn2, dropout]``.

        ``call`` unpacks the list in exactly this order.
        """
        # NOTE(review): key_dim is the per-head size; using hidden_size makes
        # every head as wide as the whole model (hidden_size //
        # num_attention_heads is the usual choice). Kept unchanged so existing
        # checkpoints still load.
        return [
            layers.MultiHeadAttention(num_heads=self.config.num_attention_heads, key_dim=self.config.hidden_size,
                                      kernel_initializer=initializers.he_normal()),
            layers.LayerNormalization(epsilon=1e-6),
            layers.Dense(self.config.intermediate_size, activation=activations.elu,
                         kernel_initializer=initializers.he_normal()),
            layers.Dense(self.config.hidden_size, kernel_initializer=initializers.he_normal()),
            layers.Dropout(self.config.dropout_rate),
        ]

    def call(self, inputs, attention_mask=None, training=False):
        """Compute next-token logits.

        Args:
            inputs: Mapping with ``'input_ids'`` (integer token ids, sequence
                along the last axis) and optionally ``'attention_mask'``.
            attention_mask: Optional padding mask; takes precedence over
                ``inputs['attention_mask']``. Assumed ``[batch, seq_len]``
                -- TODO confirm against callers.
            training: Enables dropout when True.

        Returns:
            Logits with shape ``input_ids.shape + (vocab_size,)``.
        """
        input_ids = inputs['input_ids']
        seq_len = tf.shape(input_ids)[-1]
        position_ids = tf.range(start=0, limit=seq_len, delta=1)
        # Position embeddings ([seq_len, hidden]) broadcast over the batch axis.
        hidden_states = self.token_embedding(input_ids) + self.position_embedding(position_ids)

        # train_step calls self(inputs) without the keyword argument, so fall
        # back to the mask carried in the inputs; otherwise padding is attended.
        if attention_mask is None and 'attention_mask' in inputs:
            attention_mask = inputs['attention_mask']
        if attention_mask is not None:
            # [batch, seq] -> [batch, 1, seq]: one key mask shared by every query
            # position. MultiHeadAttention ANDs it with the causal mask into
            # [batch, query_len, key_len].
            attention_mask = tf.cast(attention_mask, tf.bool)[:, tf.newaxis, :]

        for mha, norm, ffn1, ffn2, dropout in self.decoder_blocks:
            # use_causal_mask keeps each position from attending to later
            # tokens, which next-token (decoder) training requires.
            attn_output = mha(hidden_states, hidden_states, attention_mask=attention_mask,
                              use_causal_mask=True, training=training)
            attn_output = dropout(attn_output, training=training)
            # NOTE(review): the same LayerNormalization instance serves both
            # Add & Norm steps, so they share weights; kept to preserve the
            # checkpoint layout.
            hidden_states = norm(attn_output + hidden_states)  # Add & Norm
            ffn_output = ffn2(ffn1(hidden_states))
            ffn_output = dropout(ffn_output, training=training)
            hidden_states = norm(ffn_output + hidden_states)  # Add & Norm

        hidden_states = self.layer_norm(hidden_states)
        return self.lm_head(hidden_states)

    def get_config(self):
        """Serialize the model as a copy of its MiniSunConfig attributes."""
        # Copy so that mutating the returned dict cannot alter self.config.
        return {'config': dict(vars(self.config))}

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dict produced by get_config."""
        return cls(MiniSunConfig(**config['config']))

    def compute_loss(self, labels, logits):
        """Return the unreduced sparse categorical cross-entropy of ``logits`` vs ``labels``.

        Raises:
            ValueError: If ``labels`` or ``logits`` is None.
        """
        # NOTE(review): this shadows keras.Model.compute_loss, whose signature
        # is (x, y, y_pred, sample_weight); Keras's default test_step (used by
        # evaluate()) may call it that way -- verify on the installed version.
        if labels is None or logits is None:
            raise ValueError("Labels and logits cannot be None.")
        return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

    def train_step(self, data):
        """Run one optimization step on an ``(inputs, labels)`` batch.

        The per-token loss is averaged into a scalar. When
        ``inputs['attention_mask']`` is present it weights the average, so
        padding positions do not contribute to the gradient. The scalar loss
        is included in the returned logs under ``'loss'``.
        """
        inputs, labels = data
        # NOTE(review): assumes labels and attention_mask share the
        # [batch, seq_len] layout of input_ids -- confirm with the data pipeline.
        attention_mask = inputs['attention_mask'] if 'attention_mask' in inputs else None

        with tf.GradientTape() as tape:
            logits = self(inputs, training=True)
            per_token_loss = self.compute_loss(labels, logits)
            if attention_mask is None:
                loss = tf.reduce_mean(per_token_loss)
            else:
                weights = tf.cast(attention_mask, per_token_loss.dtype)
                # divide_no_nan: an all-padding batch yields 0 instead of NaN.
                loss = tf.math.divide_no_nan(tf.reduce_sum(per_token_loss * weights),
                                             tf.reduce_sum(weights))
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Metrics receive flattened predicted ids vs flattened label ids.
        predicted_ids = tf.reshape(tf.argmax(logits, axis=-1), [-1])
        true_ids = tf.reshape(labels, [-1])
        for metric in self.metrics:
            metric.update_state(true_ids, predicted_ids)
        logs = {m.name: m.result() for m in self.metrics}
        # Report the actual batch loss (overrides any tracker named 'loss').
        logs['loss'] = loss
        return logs
def create_model(config):
    """Instantiate a MiniSunModel from ``config`` and compile it.

    Compiles with AdamW (``config.learning_rate``, ``config.weight_decay``),
    a from-logits sparse categorical cross-entropy loss and an ``'accuracy'``
    metric.

    NOTE: MiniSunModel.train_step computes its own loss via compute_loss, so
    the compiled loss object is not what drives training.
    """
    minisun = MiniSunModel(config)
    minisun.compile(
        optimizer=tf.keras.optimizers.AdamW(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return minisun
def cosine_annealing_with_warmup(step, config):
    """Learning rate for ``step``: linear warm-up followed by cosine annealing.

    The rate ramps linearly from 0 to ``config.learning_rate`` over the first
    ``int(config.total_steps * config.warmup_steps)`` steps. It then follows a
    half cosine down to 0 at ``config.total_steps``, and stays at 0 after that.

    Args:
        step: Non-negative step index.
        config: Object with ``learning_rate``, ``total_steps`` and
            ``warmup_steps`` (the warm-up *fraction* of ``total_steps``).

    Returns:
        The learning rate as a Python float.
    """
    warmup_steps = int(config.total_steps * config.warmup_steps)
    if step < warmup_steps:
        return config.learning_rate * (step / warmup_steps)
    decay_steps = max(config.total_steps - warmup_steps, 0)
    # Clamp so the rate stays at its floor past total_steps instead of climbing
    # back up the next half of the cosine. max(..., 1) avoids dividing by zero
    # when warm-up covers every step; the rate then holds at its peak.
    progress = min(step - warmup_steps, decay_steps) / max(decay_steps, 1)
    return 0.5 * config.learning_rate * (1 + math.cos(math.pi * progress))
def cosine_annealing_with_restarts(step, config, restart_period, cycle_num):
    """Learning rate for ``step``: warm-up plus cosine annealing, restarted periodically.

    The schedule repeats every ``restart_period`` steps. Within each period
    the rate ramps linearly from 0 to ``config.learning_rate`` over the first
    ``int(config.total_steps * config.warmup_steps)`` steps, then follows a
    half cosine towards 0 at the end of the period. If that warm-up length is
    >= ``restart_period``, the rate never reaches its peak.

    Args:
        step: Non-negative global step index.
        config: Object with ``learning_rate``, ``total_steps`` and
            ``warmup_steps`` (the warm-up *fraction* of ``total_steps``).
        restart_period: Cycle length in steps; must be positive.
        cycle_num: Unused; kept for call-site compatibility.

    Returns:
        The learning rate as a Python float.

    Raises:
        ValueError: If ``restart_period`` is not positive.
    """
    if restart_period <= 0:
        raise ValueError(f"restart_period must be positive, got {restart_period}")
    warmup_steps = int(config.total_steps * config.warmup_steps)
    effective_step = step % restart_period
    if effective_step < warmup_steps:
        return config.learning_rate * (effective_step / warmup_steps)
    # Here warmup_steps <= effective_step < restart_period, so the decay span
    # below is at least one step.
    cos_step = effective_step - warmup_steps
    total_cos_steps = restart_period - warmup_steps
    return 0.5 * config.learning_rate * (1 + math.cos(math.pi * cos_step / total_cos_steps))
class _PerStepLearningRate(tf.keras.callbacks.Callback):
    """Set the optimizer's learning rate from ``schedule(step)`` before every training batch.

    ``tf.keras.callbacks.LearningRateScheduler`` calls its schedule once per
    epoch with the *epoch* index. The schedules above are written in optimizer
    steps (``config.total_steps``), so with the defaults the rate would stay
    at exactly 0 for epoch 0 and near 0 for hundreds of epochs. This callback
    counts training batches instead. The counter persists across epochs and
    across repeated fit() calls on the same callback instance.
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule
        self._step = 0

    def on_train_batch_begin(self, batch, logs=None):
        self.model.optimizer.learning_rate = float(self.schedule(self._step))
        self._step += 1


# Configuration
config = MiniSunConfig()
# Initialize model with He initialization
model = create_model(config)
# Per-step learning-rate schedule callback (pass to model.fit(callbacks=[...])).
lr_scheduler = _PerStepLearningRate(lambda step: cosine_annealing_with_warmup(step, config))
# lr_scheduler_with_restarts = _PerStepLearningRate(lambda step: cosine_annealing_with_restarts(step, config, restart_period=1000, cycle_num=1))