"""Training configuration for the smallama pretraining / finetuning pipeline."""

from trl import SFTConfig
class Config:
    """Hyperparameters, dataset identifiers, and repo settings for LLM
    pretraining and instruction finetuning, plus a factory (`_args` /
    `getConfig`) that builds the matching `trl.SFTConfig`."""

    def __init__(self):
        # Model and training hyperparameters
        self.BATCH_SIZE = 16
        self.EPOCHS = 3
        self.LEARNING_RATE = 2e-4
        self.MAX_SEQ_LENGTH = 512
        self.VOCAB_SIZE = 32000
        self.FP16 = True
        self.WEIGHT_DECAY = 1e-3
        # Derived so the effective batch is BATCH_SIZE * steps = 64.
        self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4

        # Dataset configurations
        self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
        self.INSTRUCT_DATASET = "nroggendorff/elephant"
        # Examples taken per dataset shard (200_000).
        self.SHARD_SIZE = int(2e+5)

        # Output and repo settings
        self.OUTPUT_REPO = "nroggendorff/smallama"
        self.PUSH_TO_HUB = True
        self.INSTRUCT_FINETUNE_BOOL = False

        # Training steps and warmup
        # 12**3 // 2 == 864; not referenced in this file — presumably a
        # model-size scaling factor consumed elsewhere (TODO confirm).
        self.FACTOR = 12 ** 3 // 2
        # Optimizer steps per run: examples seen / effective batch size.
        self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (
            self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS
        )
        # Warm up over the first 10% of optimizer steps.
        self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)

        # Initial state for shard offset
        self.INIT = 0

    def getConfig(self):
        """Backward-compatible accessor for the training arguments.

        (Was previously a lambda attribute bound in ``__init__``; a plain
        method keeps ``cfg.getConfig()`` working while staying PEP 8 clean.)
        """
        return self._args()

    def _args(self):
        """Build a ``trl.SFTConfig`` mirroring these hyperparameters.

        Returns:
            SFTConfig: arguments object ready to hand to an ``SFTTrainer``.
        """
        return SFTConfig(
            output_dir="model",
            num_train_epochs=self.EPOCHS,
            per_device_train_batch_size=self.BATCH_SIZE,
            learning_rate=self.LEARNING_RATE,
            warmup_steps=self.WARMUP_STEPS,
            weight_decay=self.WEIGHT_DECAY,
            gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
            fp16=self.FP16,
            # Checkpoint every 5 warmup-lengths; log once per warmup-length.
            save_steps=int(self.WARMUP_STEPS * 5),
            logging_steps=int(self.WARMUP_STEPS),
            save_total_limit=2,
            report_to="none",
        )