import os
import sys
import time

from functools import partial
from typing import Any, Dict, Optional, Tuple
from warnings import warn

import torch
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
    get_adapter_params,
    get_merged_lora_ckpt,
    set_trainable_params,
    validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from tqdm import tqdm

log = utils.get_logger("DEBUG")

class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
    """
    Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This
    recipe supports distributed training and can be run on a single node (1 to 8 GPUs).

    Features:
        - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training
            on CPU is not supported.

        - Activation Checkpointing. This can be controlled using the
            ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce
            the memory footprint since we no longer keep activations in memory and instead
            recompute them during the backward pass. This is especially helpful for larger
            batch sizes when you're memory constrained. These memory savings come at the cost
            of training performance: in most cases training will slow down noticeably as a
            result of the activation recomputation.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using
            the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer
            states are in bfloat16. In most cases this should halve the memory footprint of
            full precision (fp32) training, without loss in model quality (this will depend
            on the model, training data and other settings). For GPUs which do not support
            bfloat16, we fall back to fp32. Mixed precision training and fp16 precision are
            currently not supported.

        - Gradient Accumulation. You can simulate larger batch sizes by accumulating
            gradients. This is controlled using the ``gradient_accumulation_steps`` flag.

                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.

            For example: with batch_size=1, nproc_per_node=2 and
            gradient_accumulation_steps=32 we get a total batch size of 64.

            Gradient accumulation is especially useful when you are memory constrained. In
            this case, accumulating gradients might give you better training speed than
            enabling activation checkpointing.

        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at
            the end of training. Currently we checkpoint both the adapter weights (trainable
            params only) and the complete merged weights (adapter weights added back to the
            base model). For more details please take a look at our LoRA tutorial
            (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).

            Optimizer state and recipe state (seed, total_epochs, number of epochs run, etc.)
            are only saved at the end of a given epoch and used in case of resuming training.
            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch
            checkpointing is currently not supported.

            For more details on the checkpointer, please take a look at our checkpointer
            deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).

        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

    For a full list of example configs for this recipe, run ``tune ls`` on the command line.
    Each config has example commands for how to kick off training.
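
    For example, a single-node run on two GPUs might be launched as follows (the config name
    below is illustrative; substitute one listed by ``tune ls``)::

        tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora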

    Args:
        cfg (DictConfig): OmegaConf object parsed from the yaml config file.

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        ValueError: If ``world_size`` is 1.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)

        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        _, rank = utils.get_world_size_and_rank()

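        # Rank 0 acts as the coordinator for logging and progress reporting.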
        self._is_rank_zero = rank == 0

        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        self._enable_activation_checkpointing = cfg.enable_activation_checkpointing

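        # These attributes constitute the recipe state; they are overwritten from the
        # checkpoint when ``resume_from_checkpoint=True``.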
        self.seed = utils.set_seed(seed=cfg.seed)
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
        """
        Extract the checkpoint state from file and validate. This includes the
        base model weights. If resume_from_checkpoint is True, this also includes
        the adapter weights and recipe state.
        """
        self._checkpointer = config.instantiate(
            cfg_checkpointer,
            resume_from_checkpoint=self._resume_from_checkpoint,
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()

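        # When resuming, the checkpoint must also contain the adapter weights and the
        # recipe state saved at the previous epoch boundary.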
        if self._resume_from_checkpoint:
            if utils.ADAPTER_KEY not in checkpoint_dict:
                raise ValueError(
                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
                )
            self._update_recipe_state(checkpoint_dict)
        return checkpoint_dict

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Updates the recipe state from checkpoint.
        """
        if not (
            utils.SEED_KEY in ckpt_dict
            and utils.TOTAL_EPOCHS_KEY in ckpt_dict
            and utils.MAX_STEPS_KEY in ckpt_dict
        ):
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            )

        if (
            self.seed != ckpt_dict[utils.SEED_KEY]
            or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
            or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
        ):
            warn(
                message=(
                    "Configured value for seed, epochs or max_steps_per_epoch "
                    "does not match the value stored in checkpoint."
                )
            )
        self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
        self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
        self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
        self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]

    def setup(self, cfg: DictConfig) -> None:
        """
        Set up the recipe. This includes recipe state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
        """
        if self._is_rank_zero:
            self._metric_logger = config.instantiate(cfg.metric_logger)
            self._metric_logger.log_config(cfg)

        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            base_model_state_dict=checkpoint_dict[utils.MODEL_KEY],
            lora_weights_state_dict=(
                checkpoint_dict[utils.ADAPTER_KEY]
                if self._resume_from_checkpoint
                else None
            ),
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            opt_state_dict=(
                checkpoint_dict[utils.OPT_KEY]
                if self._resume_from_checkpoint
                else None
            ),
        )

        self._loss_fn = config.instantiate(cfg.loss)

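        # The sampler and dataloader depend on the tokenizer and loss_fn, so they are
        # set up after both of these are initialized.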
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
        )

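        # The number of training steps per epoch depends on the dataloader length, the
        # gradient_accumulation_steps flag, and the optional max_steps_per_epoch cap.
        # global_step is then recomputed so resumed runs continue from the right step.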
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

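        # The scheduler needs the total number of optimizer steps; last_epoch keeps the
        # learning rate curve aligned when resuming mid-run.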
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.lr_scheduler,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
    ) -> nn.Module:
        """
        Model initialization has some important considerations:
            a. To minimize GPU peak memory, we load the model on CPU with the right
               dtype. To ensure that we don't instantiate ``world_size`` number of models,
               we initialize on the meta device for all ranks other than rank 0.
            b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
               model weights from checkpoint.
            c. While wrapping the model with FSDP, we set ``sync_module_states``
               to ``True`` and broadcast module params and buffers from rank 0.
            d. The ``device_id`` param ensures that the FSDP initialization happens on
               the correct device.
        """
        if self._is_rank_zero:
            log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
            init_start = time.perf_counter()

            with utils.set_default_dtype(self._dtype):
                model = config.instantiate(cfg_model)

            log.info(
                f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
            )

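            # Validate that the checkpoint keys line up with the LoRA config before
            # attempting to load: the base checkpoint should cover all non-LoRA params
            # and, when resuming, the adapter checkpoint should cover the LoRA params.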
            validate_state_dict_for_lora(
                lora_attn_modules=cfg_model.lora_attn_modules,
                apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
                apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
                full_model_state_dict_keys=model.state_dict().keys(),
                lora_state_dict_keys=(
                    lora_weights_state_dict.keys()
                    if lora_weights_state_dict is not None
                    else None
                ),
                base_model_state_dict_keys=base_model_state_dict.keys(),
            )

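            # strict=False is intentional: the base checkpoint is missing the LoRA
            # params and the adapter checkpoint is missing the base params.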
            model.load_state_dict(base_model_state_dict, strict=False)
            if lora_weights_state_dict:
                model.load_state_dict(lora_weights_state_dict, strict=False)

        else:
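            # Non-zero ranks build the model on the meta device; storage is allocated
            # later by FSDP's param_init_fn and filled by the rank 0 broadcast.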
            with utils.set_default_dtype(self._dtype), torch.device("meta"):
                model = config.instantiate(cfg_model)

        if self._dtype == torch.bfloat16:
            model = model.to(torch.bfloat16)

        self._lora_rank = cfg_model.lora_rank
        self._lora_alpha = cfg_model.lora_alpha

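        # Mark only the LoRA adapter params as trainable; the base model stays frozen.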
        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

        model = FSDP(
            module=model,
            auto_wrap_policy=utils.lora_fsdp_wrap_policy(
                modules_to_wrap={modules.TransformerDecoderLayer}
            ),
            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
            device_id=self._device,
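            # Training happens directly in ``self._dtype``, so FSDP mixed precision
            # is intentionally disabled.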
            mixed_precision=None,
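            # Broadcast params and buffers from rank 0 so all ranks start from the
            # checkpoint weights loaded above.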
            sync_module_states=True,
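            # Non-zero ranks still hold meta tensors here; allocate real (empty)
            # storage on device so the rank 0 broadcast can fill it.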
            param_init_fn=(
                lambda module: module.to_empty(
                    device=torch.device("cuda"), recurse=False
                )
                if not self._is_rank_zero
                else None
            ),
        )

        utils.validate_no_params_on_meta_device(model)

        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )
        if self._is_rank_zero:
            memory_stats = utils.get_memory_stats(device=self._device)
            utils.log_memory_stats(memory_stats)

        torch.distributed.barrier()

        return model

    def _setup_optimizer(
        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
    ) -> Optimizer:
        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
        if opt_state_dict:
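            # The saved optimizer state is a full (unsharded) state dict; convert it
            # to the sharded layout this rank's FSDP-wrapped model expects.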
            opt_state_dict = utils.transform_opt_state_dict(
                opt_state_dict, self._model, optimizer
            )
            optimizer.load_state_dict(opt_state_dict)

        if self._is_rank_zero:
            log.info("Optimizer is initialized.")
        return optimizer

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: DictConfig,
        num_training_steps: int,
        last_epoch: int,
    ) -> torch.optim.lr_scheduler.LRScheduler:
        lr_scheduler = config.instantiate(
            cfg_lr_scheduler,
            self._optimizer,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )
        if self._is_rank_zero:
            log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports
        ``DistributedSampler`` with map-style datasets which fit into memory. Other
        samplers, iterable datasets, and streaming datasets are not supported.
        """
        world_size, rank = utils.get_world_size_and_rank()

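        # A ListConfig means multiple datasets were specified; instantiate each one
        # and concatenate them. ``packed`` is forced to False for the concatenated case.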
        if isinstance(cfg_dataset, ListConfig):
            datasets = [
                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
                for single_cfg_dataset in cfg_dataset
            ]
            ds = ConcatDataset(datasets=datasets)
            packed = False
        else:
            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
            packed = cfg_dataset.get("packed", False)

        sampler = DistributedSampler(
            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
        )

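        # Packed datasets handle sequence layout themselves, so no collation padding
        # is applied; otherwise pad each batch and mask the padding out of the loss.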
        dataloader = DataLoader(
            dataset=ds,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=(
                partial(
                    utils.padded_collate,
                    padding_idx=self._tokenizer.pad_id,
                    ignore_idx=self._loss_fn.ignore_index,
                )
                if not packed
                else None
            ),
        )

        if self._is_rank_zero:
            log.info("Dataset and Sampler are initialized.")

        return sampler, dataloader

    def save_checkpoint(
        self,
        epoch: int,
    ) -> None:
        """
        Checkpoint the state of the recipe. The constructed checkpoint state dict
        contains the following information:
        - Merged weights with key MODEL_KEY
        - Adapter weights with key ADAPTER_KEY
        - Relevant recipe state if training is not complete

        The checkpointer saves the merged weights, adapter weights, and recipe state
        in different checkpoint files. To correctly resume training, the adapter
        weights and recipe state must be provided along with the base model weights.
        """
        checkpoint_dict = {}

        intermediate_checkpoint = epoch + 1 < self.total_epochs

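        # Gather the full, unsharded model (and optionally optimizer) state dict on
        # CPU; rank0_only avoids materializing the full weights on every rank.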
        with FSDP.state_dict_type(
            self._model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            cpu_state_dict = self._model.state_dict()
            if intermediate_checkpoint:
                opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
            else:
                opt_state_dict = None

        if self._is_rank_zero:
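            # The adapter checkpoint contains only the trainable LoRA params, i.e.
            # the keys registered in ``self.adapter_params``.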
            adapter_key_filter = lambda x: x in self.adapter_params
            adapter_state_dict = {
                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
            }
            checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict})

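            # Merge the LoRA deltas back into the base weights so the merged
            # checkpoint can be loaded without torchtune's LoRA modules.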
            merged_state_dict = get_merged_lora_ckpt(
                cpu_state_dict,
                rank=self._lora_rank,
                alpha=self._lora_alpha,
            )
            checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict})

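            # Optimizer state and recipe state are only needed to resume training,
            # so they are saved with intermediate checkpoints only.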
            if intermediate_checkpoint:
                checkpoint_dict.update(
                    {
                        utils.OPT_KEY: opt_state_dict,
                        utils.SEED_KEY: self.seed,
                        utils.EPOCHS_KEY: self.epochs_run,
                        utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                        utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                    }
                )

            self._checkpointer.save_checkpoint(
                checkpoint_dict,
                epoch=epoch,
                intermediate_checkpoint=intermediate_checkpoint,
            )

    def train(self) -> None:
        """
        The core training loop.
        """
        utils.cleanup_before_training()

        _, rank = utils.get_world_size_and_rank()

        self._optimizer.zero_grad()

        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        for curr_epoch in range(self.epochs_run, self.total_epochs):
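            # Update the sampler so each epoch sees a different shuffling order
            # when shuffle=True.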
            self._sampler.set_epoch(curr_epoch)

            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                tokens, labels = batch["tokens"], batch["labels"]

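                # ``mask`` and ``input_pos`` are only present for packed datasets,
                # which carry their own attention mask and position ids.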
                mask = batch.get("mask", None)
                input_pos = batch.get("input_pos", None)

                tokens = tokens.to(self._device)
                num_tokens += tokens.numel()
                labels = labels.to(self._device)
                mask = mask.to(self._device) if mask is not None else None
                input_pos = (
                    input_pos.to(self._device) if input_pos is not None else None
                )

                logits = self._model(tokens, mask=mask, input_pos=input_pos)

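                # Shift so that tokens < n predict token n, then move the class
                # dimension to dim 1, the layout cross-entropy expects.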
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
                logits = logits.transpose(1, 2)

                loss = self._loss_fn(logits, labels)

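                # Scale the loss so gradients accumulated over
                # ``gradient_accumulation_steps`` micro-batches average correctly.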
                loss = loss / self._gradient_accumulation_steps
                running_loss += loss
                loss.backward()

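                # Step the optimizer and scheduler only once a full effective batch
                # of gradients has been accumulated.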
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()

                    self.global_step += 1

                    loss_to_log = running_loss.item()
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    )

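                    # Log per-step metrics from rank 0 only, once every
                    # ``log_every_n_steps`` optimizer steps.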
                    if (
                        self.global_step % self._log_every_n_steps == 0
                        and self._is_rank_zero
                    ):
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": self._optimizer.param_groups[0]["lr"],
                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(utils.get_memory_stats(device=self._device))
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

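                    # Reset running stats for the next accumulation window.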
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

    def cleanup(self) -> None:
        if self._is_rank_zero:
            self._metric_logger.close()
        destroy_process_group()


@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in the config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command line
    """
    if not utils.is_distributed():
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using the tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]."
        )

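    # Hint NCCL to avoid recordStream-based caching of communication buffers, which
    # can otherwise hold on to GPU memory longer than necessary with FSDP.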
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")

    config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)

    recipe = LoRAFinetuneRecipeDistributed(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.train()
    recipe.cleanup()


if __name__ == "__main__":
    sys.exit(recipe_main())