|
import contextlib |
|
import logging |
|
import math |
|
import warnings |
|
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union |
|
from .utils import init_empty_weights |
|
# Module-level logger named after this module; used for non-fatal warnings below.
log: logging.Logger = logging.getLogger(__name__)
|
|
|
def pop_config(cfg: DictConfig, key: str, must_exist: bool=True, default_value: Any=None, convert: bool=False) -> Any:
    """Pop a value from the main config file and return it.

    The key is removed from ``cfg`` in place. A key whose value is ``None`` is
    treated exactly like a missing key.

    Args:
        cfg: Config to pop ``key`` from; it is mutated in place.
        key: Name of the entry to pop.
        must_exist: If True, a missing (or ``None``) value raises instead of
            falling back to ``default_value``.
        default_value: Returned when the value is missing and ``must_exist``
            is False.
        convert: If True, convert the popped value to a plain Python object
            using ``OmegaConf.to_container``.

    Returns:
        The popped value (converted when ``convert`` is True), or
        ``default_value`` when it is missing and not required.

    Raises:
        NameError: If the value is missing (or ``None``) and ``must_exist``
            is True.
        ValueError: If ``convert`` is True and the value is neither a
            ``DictConfig`` nor a ``ListConfig``.
    """
    value = cfg.pop(key, None)

    if value is None:
        if must_exist:
            raise NameError(f'The {key} parameter is missing and must exist for execution. Please check your yaml.')
        return default_value

    if not convert:
        return value

    # Only OmegaConf containers can be turned into a dict/list.
    if not isinstance(value, (DictConfig, ListConfig)):
        raise ValueError(f'The key {key} has a value of type {type(value)} that cannot be converted to a dict or list. Please check your yaml.')
    return om.to_container(value)
|
|
|
def calculate_batch_size_info(global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']]) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
    """Split a global batch size across ranks and derive gradient accumulation.

    Args:
        global_batch_size: Batch size summed over all ranks; must be divisible
            by the world size.
        device_microbatch_size: Per-rank microbatch size, or ``'auto'``.

    Returns:
        A ``(device_batch_size, device_microbatch_size, device_grad_accum)``
        tuple. An integer microbatch size larger than the per-device batch
        size is clamped down to it (with a warning). When the microbatch size
        is ``'auto'``, ``device_grad_accum`` is ``'auto'`` as well.

    Raises:
        ValueError: If ``global_batch_size`` is not divisible by the world
            size, if an integer microbatch size is not positive, if the
            resulting per-device batch size is not positive while an integer
            microbatch size is given, or if ``device_microbatch_size`` is
            neither an int nor ``'auto'``.
    """
    world_size = dist.get_world_size()
    if global_batch_size % world_size != 0:
        raise ValueError(f'Global batch size {global_batch_size} is not divisible by {world_size} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {world_size}.')

    device_batch_size = global_batch_size // world_size

    if device_microbatch_size == 'auto':
        device_grad_accum = 'auto'
    elif isinstance(device_microbatch_size, int):
        # Zero would divide by zero below; negatives give a nonsensical
        # (negative) number of accumulation steps.
        if device_microbatch_size < 1:
            raise ValueError(f'device_microbatch_size must be a positive integer, got {device_microbatch_size}.')
        # Clamping to a non-positive device batch size would divide by zero.
        if device_batch_size < 1:
            raise ValueError(f'Per-device batch size must be positive, got {device_batch_size} ' + f'(global_batch_size={global_batch_size}, world size={world_size}).')
        if device_microbatch_size > device_batch_size:
            log.warning('device_microbatch_size > device_batch_size, will be reduced from %s -> %s.', device_microbatch_size, device_batch_size)
            device_microbatch_size = device_batch_size
        # Integer ceiling division: exact for arbitrarily large ints, unlike
        # math.ceil over a float quotient.
        device_grad_accum = -(-device_batch_size // device_microbatch_size)
    else:
        raise ValueError(f'Not sure how to parse device_microbatch_size={device_microbatch_size!r}')

    return (device_batch_size, device_microbatch_size, device_grad_accum)
|
|
|
def update_batch_size_info(cfg: DictConfig) -> DictConfig:
    """Fill in the per-device batch-size fields of ``cfg`` in place.

    Derives the per-device train batch size, microbatch size and gradient
    accumulation from ``cfg.global_train_batch_size`` and
    ``cfg.device_train_microbatch_size``, and records the world size as
    ``cfg.n_gpus``. If ``device_eval_batch_size`` is not set, it defaults to
    the train microbatch size, or to 1 when that is ``'auto'``.

    Returns:
        The same ``cfg`` object, for convenience.
    """
    batch_info = calculate_batch_size_info(cfg.global_train_batch_size, cfg.device_train_microbatch_size)
    cfg.n_gpus = dist.get_world_size()
    # Attribute targets are assigned left to right, matching field order.
    (cfg.device_train_batch_size, cfg.device_train_microbatch_size, cfg.device_train_grad_accum) = batch_info

    if 'device_eval_batch_size' not in cfg:
        train_micro = cfg.device_train_microbatch_size
        cfg.device_eval_batch_size = 1 if train_micro == 'auto' else train_micro

    return cfg
|
|
|
def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
    """Validate ``model_cfg.init_device`` and adjust the FSDP config for it.

    Both ``model_cfg`` and ``fsdp_config`` may be mutated in place:

    * ``init_device`` (if present) must be ``'meta'``, ``'cpu'`` or ``'mixed'``.
    * ``'meta'`` without an FSDP config is rewritten to ``'cpu'`` with a
      warning; otherwise it selects ``init_empty_weights()`` as the returned
      context.
    * ``'mixed'`` requires an FSDP config. It forces
      ``sync_module_states=True`` (warning if it was not already set) and
      defaults ``use_orig_params`` to False and ``load_monolith_rank0_only``
      to True.
    * If ``model_cfg.master_weights_dtype`` is one of the 16-bit dtype names
      below and ``fsdp_config`` is non-empty, ``fsdp_config['mixed_precision']``
      is replaced by a dict with ``param_dtype=None`` and
      ``keep_low_precision_grads=True``. Any ``reduce_dtype``/``buffer_dtype``
      from an existing mapping-valued ``mixed_precision`` is carried over.

    Args:
        model_cfg: Model config; its ``init_device`` may be rewritten.
        fsdp_config: FSDP config dict, or None when FSDP is not used.

    Returns:
        The context to build the model under: ``init_empty_weights()`` for
        ``'meta'``, otherwise ``contextlib.nullcontext()``.

    Raises:
        ValueError: If ``init_device`` is not a supported value.
        NotImplementedError: If ``init_device`` is ``'mixed'`` without an FSDP
            config.
    """
    init_context = contextlib.nullcontext()

    if 'init_device' in model_cfg:
        # Raise instead of `assert` so the check is not stripped under `python -O`.
        if model_cfg.init_device not in ('meta', 'cpu', 'mixed'):
            raise ValueError(f"Unsupported `init_device` {model_cfg.init_device!r}; expected one of 'meta', 'cpu' or 'mixed'.")

        if fsdp_config is None and model_cfg.init_device == 'meta':
            warnings.warn("Using `cfg.model.init_device='meta'` is only valid when using FSDP! " + "Reverting to `cfg.model.init_device='cpu'`.")
            model_cfg.init_device = 'cpu'

        if model_cfg.init_device == 'meta':
            init_context = init_empty_weights()

        if model_cfg.init_device == 'mixed':
            if fsdp_config is None:
                raise NotImplementedError('Using init_device `mixed` is only supported with FSDP. ' + 'Please add a FSDP config.')
            if not fsdp_config.get('sync_module_states', False):
                warnings.warn('Setting `sync_module_states = True` for FSDP. This is required when using mixed initialization.')
                fsdp_config['sync_module_states'] = True
            fsdp_config.setdefault('use_orig_params', False)
            fsdp_config.setdefault('load_monolith_rank0_only', True)

    master_dtype = model_cfg.get('master_weights_dtype')
    small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')
    if fsdp_config and master_dtype in small_dtypes:
        mixed_precision = fsdp_config.get('mixed_precision')
        # Preserve user-chosen reduce/buffer dtypes when mixed_precision was
        # given as a mapping; any other form (e.g. a preset name) is dropped.
        if isinstance(mixed_precision, Mapping):
            reduce_dtype = mixed_precision.get('reduce_dtype')
            buffer_dtype = mixed_precision.get('buffer_dtype')
        else:
            reduce_dtype = None
            buffer_dtype = None
        # NOTE(review): param_dtype=None presumably leaves the already
        # low-precision master weights uncast; confirm against the FSDP
        # mixed-precision handling downstream.
        fsdp_config['mixed_precision'] = {'param_dtype': None, 'reduce_dtype': reduce_dtype, 'buffer_dtype': buffer_dtype, 'keep_low_precision_grads': True}

    return init_context
|
|
|
def log_config(cfg: DictConfig) -> None:
    """Logs the current config and updates the wandb and mlflow configs.

    Prints ``cfg`` as YAML to stdout. If ``wandb`` and/or ``mlflow`` appear in
    ``cfg.loggers`` and the respective library has an active run, the fully
    resolved config is pushed to that run.

    This function can be called multiple times to update the wandb and MLflow
    config with different variables.

    Raises:
        ImportError: If a configured logger's package is not installed.
    """
    print(om.to_yaml(cfg))

    # `loggers` may be absent or explicitly null; treat both as "no loggers"
    # instead of failing with `'wandb' in None`.
    loggers = cfg.get('loggers') or {}

    if 'wandb' in loggers:
        # Imported lazily so the package is only required when configured.
        import wandb
        if wandb.run:
            wandb.config.update(om.to_container(cfg, resolve=True))

    if 'mlflow' in loggers:
        import mlflow
        if mlflow.active_run():
            mlflow.log_params(params=om.to_container(cfg, resolve=True))