# Zamba2-1.2B-instruct-Dutch / optimize_lr.py
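"""Learning-rate search for fine-tuning Zyphra/Zamba2-1.2B on BramVanroy/dolly-15k-dutch.

The script runs `num_trials` Optuna trials (TPE sampler); each trial trains for one
epoch on 1,000 chat-formatted examples and is scored by the mean of the last 20% of
logged training losses. The trial results are then refined with a Gaussian Process
regression over log10(learning rate) plus an Expected Improvement acquisition, and
everything is saved to lr_optimization_results.json and lr_optimization_plot.png.
"""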
import optuna
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling
)
import torch
from datasets import load_dataset
import numpy as np
import gc
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern
import matplotlib.pyplot as plt
from scipy.stats import norm
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
from transformers import TrainerCallback
import argparse
# Configuration parameters
num_trials = 10 # Adjust this value to control the number of optimization trials
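# `argparse` is imported above but not wired up; a minimal sketch of how the trial
# count could instead be taken from the command line (hypothetical --num-trials flag,
# not part of the original script):
#
#     parser = argparse.ArgumentParser(description="Learning rate optimization")
#     parser.add_argument("--num-trials", type=int, default=10)
#     num_trials = parser.parse_args().num_trials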
DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
CONTEXT_WINDOW = 1024
# Initialize tokenizer once
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
def prepare_chat_format(examples):
    """Tokenize each conversation with the model's chat template."""
    chats = []
    for messages in examples['messages']:
        try:
            chat = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )
            chats.append(chat)
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Fallback format if the chat template fails
            text = ""
            for message in messages:
                role = message["role"]
                content = message["content"]
                text += f"<|{role}|>\n{content}</s>\n"
            chat = tokenizer(
                text,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )["input_ids"]
            chats.append(chat)
    return {"input_ids": chats}

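# Note (assumption about the dataset layout): each row of the train_sft split is
# expected to carry a `messages` list of {"role": ..., "content": ...} dicts, which
# is the structure both apply_chat_template and the manual fallback above rely on.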
# Prepare dataset once
tokenized_dataset = DATASET.map(
    prepare_chat_format,
    batched=True,
    remove_columns=DATASET.column_names
)
def clear_memory():
    """Clear GPU memory between trials"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

class LossCallback(TrainerCallback):
    """Collect the training loss from every logging step."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])

def objective(trial):
    # Clear memory from the previous trial
    clear_memory()
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)

    # Initialize the model with a fresh state
    torch.manual_seed(42)
    model = AutoModelForCausalLM.from_pretrained(
        "Zyphra/Zamba2-1.2B",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    model.config.pad_token_id = tokenizer.pad_token_id

    # Calculate steps with a larger per-device batch size
    batch_size = 4           # increased from 1
    grad_accum_steps = 8     # decreased from 32 to keep the effective batch size at 32
    effective_batch_size = batch_size * grad_accum_steps
    total_steps = len(tokenized_dataset) // effective_batch_size

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./optuna_runs/trial_{trial.number}",
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        logging_steps=max(total_steps // 20, 1),
        learning_rate=lr,
        weight_decay=0.01,
        fp16=False,
        bf16=True,
        warmup_steps=total_steps // 10,
        save_steps=1000000,
        save_total_limit=None,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,     # faster data loading
        gradient_checkpointing=True,  # reduce memory usage
        max_grad_norm=1.0             # gradient clipping for stability
    )

    print(f"\nTrial {trial.number}:")
    print(f"Learning rate: {lr}")
    print(f"Total steps: {total_steps}")
    print(f"Logging every {training_args.logging_steps} steps")

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    class CustomTrainer(Trainer):
        """Trainer that skips device placement so the model stays where device_map="auto" put it."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def _move_model_to_device(self, model, device):
            pass

    # Initialize the loss-tracking callback
    loss_callback = LossCallback()

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[loss_callback]
    )

    try:
        trainer.train()
        # Score the trial by the mean of the last 20% of logged losses
        losses = loss_callback.losses
        n_losses = max(len(losses) // 5, 1)
        final_losses = losses[-n_losses:]
        mean_loss = np.mean(final_losses) if final_losses else float('inf')
        # Clean up
        del model
        del trainer
        clear_memory()
        return mean_loss
    except Exception as e:
        print(f"Trial failed with error: {e}")
        # Clean up on failure
        del model
        del trainer
        clear_memory()
        return float('inf')

# Create and run the study
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    study_name="learning_rate_optimization"
)
study.optimize(objective, n_trials=num_trials)
# Print results
print(f"\nOptimization Results ({num_trials} trials):")
print("Best learning rate:", study.best_params["learning_rate"])
print("Best loss:", study.best_value)
print("\nAll trials:")
for trial in study.trials:
    print(f"Learning rate: {trial.params['learning_rate']:.2e}, Loss: {trial.value:.4f}")
# Save results
import json
results = {
    "best_learning_rate": study.best_params["learning_rate"],
    "best_loss": study.best_value,
    "all_trials": [(trial.params["learning_rate"], trial.value) for trial in study.trials]
}
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)
# Plot optimization history
try:
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()
except Exception as e:
    print(f"Could not create visualization: {e}")
# Refine the best learning rate with a Gaussian Process Regression over the trial results
def optimize_final_lr(study):
    try:
        # Extract learning rates and losses
        X = np.array([[trial.params['learning_rate']] for trial in study.trials])
        y = np.array([trial.value for trial in study.trials])

        # Check whether we have any valid results
        valid_mask = np.isfinite(y)
        if not np.any(valid_mask):
            print("No valid trials found. Returning default learning rate.")
            return {
                'gpr_optimal_lr': 2e-5,  # default fallback
                'ei_optimal_lr': 2e-5,
                'predicted_loss': float('inf'),
                'uncertainty': float('inf')
            }

        # Filter out infinite values
        X = X[valid_mask]
        y = y[valid_mask]

        # Ensure we have enough points for fitting
        if len(X) < 2:
            print("Not enough valid trials for GPR. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }

        # Transform to log space
        X_log = np.log10(X)

        # Normalize the y values
        y_mean = np.mean(y)
        y_std = np.std(y)
        if y_std == 0:
            y_std = 1
        y_normalized = (y - y_mean) / y_std

        # Define the kernel
        kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)

        # Fit the Gaussian Process
        gpr = GaussianProcessRegressor(
            kernel=kernel,
            n_restarts_optimizer=10,
            random_state=42,
            normalize_y=False  # y is normalized manually above
        )
        try:
            gpr.fit(X_log, y_normalized)
        except np.linalg.LinAlgError:
            print("GPR fitting failed. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }

        # Create a fine grid of points for prediction
        X_pred_log = np.linspace(np.log10(X.min()), np.log10(X.max()), 1000).reshape(-1, 1)

        # Predict mean and std
        y_pred_normalized, sigma = gpr.predict(X_pred_log, return_std=True)

        # Denormalize the predictions
        y_pred = y_pred_normalized * y_std + y_mean
        sigma = sigma * y_std

        # Find the point with the lowest predicted loss
        best_idx = np.argmin(y_pred)
        optimal_lr = 10 ** X_pred_log[best_idx, 0]
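        # Expected Improvement for a minimisation problem: with the GP posterior mean mu(x),
        # standard deviation sigma(x), and best observed loss f_best,
        #     Z(x)  = (f_best - mu(x)) / sigma(x)
        #     EI(x) = sigma(x) * (Z(x) * Phi(Z(x)) + phi(Z(x)))
        # where Phi and phi are the standard normal CDF and PDF (norm.cdf / norm.pdf below).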
        # Calculate the acquisition function (Expected Improvement)
        best_f = np.min(y)
        Z = (best_f - y_pred) / (sigma + 1e-9)  # small constant to prevent division by zero
        ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))

        # Find the point with the highest expected improvement
        ei_best_idx = np.argmax(ei)
        ei_optimal_lr = 10 ** X_pred_log[ei_best_idx, 0]

        return {
            'gpr_optimal_lr': float(optimal_lr),
            'ei_optimal_lr': float(ei_optimal_lr),
            'predicted_loss': float(y_pred[best_idx]),
            'uncertainty': float(sigma[best_idx])
        }
    except Exception as e:
        print(f"Optimization failed with error: {e}")
        return {
            'gpr_optimal_lr': 2e-5,  # default fallback
            'ei_optimal_lr': 2e-5,
            'predicted_loss': float('inf'),
            'uncertainty': float('inf')
        }

# Run final optimization and handle potential failures
try:
    final_optimization = optimize_final_lr(study)
    print("\nAdvanced Optimization Results:")
    print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
    print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
    print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
    print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
except Exception as e:
    print(f"Final optimization failed: {e}")
    final_optimization = {
        'gpr_optimal_lr': 2e-5,
        'ei_optimal_lr': 2e-5,
        'predicted_loss': float('inf'),
        'uncertainty': float('inf')
    }
# Save extended results
results.update({
    "gpr_optimal_lr": float(final_optimization['gpr_optimal_lr']),
    "ei_optimal_lr": float(final_optimization['ei_optimal_lr']),
    "predicted_loss": float(final_optimization['predicted_loss']),
    "uncertainty": float(final_optimization['uncertainty'])
})
# Visualization of the GPR results
def plot_gpr_results(study, final_optimization):
    """Plot the GPR fit over the finite trial results and the selected learning rates."""
    # Extract data and filter out infinite values
    X = np.array([[trial.params['learning_rate']] for trial in study.trials])
    y = np.array([trial.value for trial in study.trials])

    # Create a mask for finite values
    finite_mask = np.isfinite(y)
    X = X[finite_mask]
    y = y[finite_mask]

    # Check that we have enough valid points
    if len(X) < 2:
        print("Not enough valid points for GPR visualization")
        return

    # Create prediction points
    X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
    X_pred_log = np.log10(X_pred)

    # Fit a GPR for plotting
    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
    gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, random_state=42)
    gpr.fit(np.log10(X), y)

    # Predict mean and std
    y_pred, sigma = gpr.predict(X_pred_log, return_std=True)

    plt.figure(figsize=(12, 6))
    plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
    plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
    plt.fill_between(X_pred.ravel(),
                     y_pred - 2 * sigma,
                     y_pred + 2 * sigma,
                     color='blue',
                     alpha=0.2,
                     label='95% Confidence')

    # Only plot the optimal-LR lines if they are finite
    if np.isfinite(final_optimization['gpr_optimal_lr']):
        plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--',
                    label='GPR Optimal LR')
    if np.isfinite(final_optimization['ei_optimal_lr']):
        plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--',
                    label='EI Optimal LR')

    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Optimization Results with GPR')
    plt.legend()
    plt.grid(True)
    plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
    plt.close()

plot_gpr_results(study, final_optimization)
# Save all results
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)
# Store best learning rate as a variable for finetune.py to use
best_lr = study.best_params["learning_rate"]
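# How finetune.py consumes this value is an assumption: it could import `best_lr`
# from this module (which would re-run the whole study) or, more practically, read
# "best_learning_rate" / "gpr_optimal_lr" from lr_optimization_results.json.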