# train-llama / config.txt
from trl import SFTConfig
class Config:
    def __init__(self):
        # Model and training hyperparameters
        self.BATCH_SIZE = 16
        self.EPOCHS = 3
        self.LEARNING_RATE = 2e-4
        self.MAX_SEQ_LENGTH = 512
        self.VOCAB_SIZE = 32000
        self.FP16 = True
        self.WEIGHT_DECAY = 1e-3
        # 16 // 4 = 4 accumulation steps, for an effective batch size of 64
        self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4

        # Dataset configurations
        self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
        self.INSTRUCT_DATASET = "nroggendorff/elephant"
        self.SHARD_SIZE = int(2e+5)  # 200,000 examples per shard

        # Output and repo settings
        self.OUTPUT_REPO = "nroggendorff/smallama"
        self.PUSH_TO_HUB = True
        self.INSTRUCT_FINETUNE_BOOL = False

        # Training steps and warmup
        self.FACTOR = 12 ** 3 // 2  # 864; not referenced elsewhere in this file
        # (200,000 * 3) // (16 * 4) = 9375 optimizer steps per shard
        self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS)
        self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)  # 10% of total steps

        # Initial state for shard offset
        self.INIT = 0

        # Convenience handle so callers can use Config().getConfig()
        self.getConfig = lambda: self._args()
    def _args(self):
        # Build the TRL SFTConfig from the hyperparameters above
        return SFTConfig(
            output_dir="model",
            num_train_epochs=self.EPOCHS,
            per_device_train_batch_size=self.BATCH_SIZE,
            learning_rate=self.LEARNING_RATE,
            warmup_steps=self.WARMUP_STEPS,
            weight_decay=self.WEIGHT_DECAY,
            gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
            fp16=self.FP16,
            save_steps=int(self.WARMUP_STEPS * 5),  # checkpoint every 5 warmup-lengths
            logging_steps=int(self.WARMUP_STEPS),
            save_total_limit=2,  # keep only the two most recent checkpoints
            report_to="none",
        )
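
# Usage sketch (an assumption, not part of the original file): if this file is
# saved as config.py on the import path, a training script could consume it
# with trl's SFTTrainer roughly like so:
#
#     from config import Config
#
#     config = Config()
#     args = config.getConfig()  # same as config._args(); returns the SFTConfig
#     trainer = SFTTrainer(model=model, args=args, train_dataset=dataset)
#     trainer.train()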