from trl import SFTConfig


class Config:
    def __init__(self):
        # Model and training hyperparameters
        self.BATCH_SIZE = 16
        self.EPOCHS = 3
        self.LEARNING_RATE = 2e-4
        self.MAX_SEQ_LENGTH = 512
        self.VOCAB_SIZE = 32000
        self.FP16 = True
        self.WEIGHT_DECAY = 1e-3
        self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4

        # Dataset configurations
        self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
        self.INSTRUCT_DATASET = "nroggendorff/elephant"
        self.SHARD_SIZE = int(2e+5)

        # Output and repo settings
        self.OUTPUT_REPO = "nroggendorff/smallama"
        self.PUSH_TO_HUB = True
        self.INSTRUCT_FINETUNE_BOOL = False

        # Training steps and warmup
        self.FACTOR = 12 ** 3 // 2
        self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (
            self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS
        )
        self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)

        # Initial state for shard offset
        self.INIT = 0

        # Expose the SFTConfig factory as a callable attribute: config.getConfig()
        self.getConfig = lambda: self._args()

    def _args(self):
        # Build the trl SFTConfig (a TrainingArguments subclass) from the values above.
        return SFTConfig(
            output_dir="model",
            num_train_epochs=self.EPOCHS,
            per_device_train_batch_size=self.BATCH_SIZE,
            learning_rate=self.LEARNING_RATE,
            warmup_steps=self.WARMUP_STEPS,
            weight_decay=self.WEIGHT_DECAY,
            gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
            fp16=self.FP16,
            save_steps=int(self.WARMUP_STEPS * 5),
            logging_steps=int(self.WARMUP_STEPS),
            save_total_limit=2,
            report_to="none",
        )
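
# Usage sketch (illustrative addition, not part of the original file): instantiate Config,
# pull the derived SFTConfig via getConfig(), and inspect the computed schedule values.
# In the actual training script these args are assumed to be handed to a trl trainer.
if __name__ == "__main__":
    config = Config()
    training_args = config.getConfig()

    # Derived values: 200000 * 3 // (16 * 4) = 9375 optimizer steps, 10% warmup = 937 steps.
    print("total steps:", config.TOTAL_STEPS)
    print("warmup steps:", config.WARMUP_STEPS)
    print("effective batch size:", config.BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS)
    print("output dir:", training_args.output_dir)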