import contextlib
import logging
import math
import warnings
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union

# NOTE(review): DictConfig, ListConfig, om (OmegaConf) and dist are used below
# but are not imported in this view -- presumably
# `from omegaconf import DictConfig, ListConfig, OmegaConf as om` and
# `from composer.utils import dist`; confirm they are imported in the real file.
from .utils import init_empty_weights

log = logging.getLogger(__name__)


def pop_config(cfg: DictConfig,
               key: str,
               must_exist: bool = True,
               default_value: Any = None,
               convert: bool = False) -> Any:
    """Pop a value from the main config and return it.

    The key is removed from ``cfg`` as a side effect. A stored value of
    ``None`` is treated exactly like a missing key.

    Args:
        cfg: The config to pop from; mutated in place.
        key: The key to pop.
        must_exist: If True, a missing key raises ``NameError``; otherwise
            ``default_value`` is returned.
        default_value: Returned when the key is missing and ``must_exist``
            is False.
        convert: If True, convert the value to a plain Python container with
            ``om.to_container``. Only ``DictConfig``/``ListConfig`` values can
            be converted.

    Returns:
        The popped (optionally converted) value, or ``default_value``.

    Raises:
        ValueError: If ``convert`` is True and the value is neither a
            ``DictConfig`` nor a ``ListConfig``.
        NameError: If the key is missing and ``must_exist`` is True.
    """
    value = cfg.pop(key, None)
    if value is None:
        if must_exist:
            raise NameError(
                f'The {key} parameter is missing and must exist for execution. Please check your yaml.'
            )
        return default_value
    if convert:
        if not isinstance(value, (DictConfig, ListConfig)):
            raise ValueError(
                f'The key {key} has a value of type {type(value)} that cannot be converted to a dict or list. Please check your yaml.'
            )
        return om.to_container(value)
    return value


def calculate_batch_size_info(
    global_batch_size: int,
    device_microbatch_size: Union[int, Literal['auto']],
) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
    """Split a global batch size into per-device batch/microbatch/grad-accum.

    Args:
        global_batch_size: Total batch size across all ranks; must be
            divisible by the world size.
        device_microbatch_size: Positive int, or ``'auto'``. An int larger
            than the per-device batch size is clamped down to it (with a
            warning).

    Returns:
        ``(device_batch_size, device_microbatch_size, device_grad_accum)``.
        When the microbatch size is ``'auto'`` the grad-accum is ``'auto'``
        too; otherwise it is ``ceil(device_batch_size / microbatch_size)``.

    Raises:
        ValueError: If the global batch size is not divisible by the world
            size, if an int microbatch size is not positive, or if the
            microbatch size is neither an int nor ``'auto'``.
    """
    # Query once; the value does not change within this call.
    world_size = dist.get_world_size()
    if global_batch_size % world_size != 0:
        raise ValueError(
            f'Global batch size {global_batch_size} is not divisible by {world_size} '
            + 'as a result, the batch size would be truncated, please adjust `global_batch_size` '
            + f'to be divisible by world size, {world_size}.')
    device_batch_size = global_batch_size // world_size

    if device_microbatch_size == 'auto':
        device_grad_accum = 'auto'
    elif isinstance(device_microbatch_size, int):
        # A zero microbatch size would divide by zero below, and a negative
        # one would yield a meaningless negative grad-accum.
        if device_microbatch_size <= 0:
            raise ValueError(
                f'device_microbatch_size must be a positive integer or "auto", '
                f'got device_microbatch_size={device_microbatch_size!r}')
        if device_microbatch_size > device_batch_size:
            log.warning(
                f'device_microbatch_size > device_batch_size, ' +
                f'will be reduced from {device_microbatch_size} -> {device_batch_size}.'
            )
            device_microbatch_size = device_batch_size
        device_grad_accum = math.ceil(device_batch_size /
                                      device_microbatch_size)
    else:
        raise ValueError(
            f'Not sure how to parse device_microbatch_size={device_microbatch_size!r}'
        )
    return (device_batch_size, device_microbatch_size, device_grad_accum)


def update_batch_size_info(cfg: DictConfig) -> DictConfig:
    """Fill per-device batch-size fields on ``cfg`` in place and return it.

    Sets ``n_gpus``, ``device_train_batch_size``,
    ``device_train_microbatch_size`` and ``device_train_grad_accum`` from
    ``global_train_batch_size`` and ``device_train_microbatch_size``. If
    ``device_eval_batch_size`` is absent, it defaults to the train microbatch
    size, or to 1 when that is ``'auto'``.
    """
    (device_train_batch_size, device_train_microbatch_size,
     device_train_grad_accum) = calculate_batch_size_info(
         cfg.global_train_batch_size, cfg.device_train_microbatch_size)
    cfg.n_gpus = dist.get_world_size()
    cfg.device_train_batch_size = device_train_batch_size
    cfg.device_train_microbatch_size = device_train_microbatch_size
    cfg.device_train_grad_accum = device_train_grad_accum
    if 'device_eval_batch_size' not in cfg:
        if cfg.device_train_microbatch_size == 'auto':
            cfg.device_eval_batch_size = 1
        else:
            cfg.device_eval_batch_size = cfg.device_train_microbatch_size
    return cfg


def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
    """Validate ``model_cfg.init_device`` and return a model-init context.

    Handles ``init_device``:
    - ``'meta'`` without FSDP: warns and falls back to ``'cpu'``.
    - ``'meta'`` with FSDP: returns ``init_empty_weights()``.
    - ``'mixed'``: requires FSDP, forces ``sync_module_states=True`` and sets
      defaults for ``use_orig_params`` / ``load_monolith_rank0_only``.

    Independently of ``init_device``, when FSDP is configured and
    ``master_weights_dtype`` is a 16-bit dtype, ``fsdp_config['mixed_precision']``
    is replaced with one that leaves ``param_dtype`` unset and keeps
    low-precision grads. Any existing ``reduce_dtype`` / ``buffer_dtype`` are
    carried over.

    Both ``model_cfg`` and ``fsdp_config`` may be mutated in place.

    Returns:
        A context manager to wrap model construction in (``nullcontext`` unless
        meta init is used).

    Raises:
        ValueError: If ``init_device`` is not one of ``'meta'``, ``'cpu'``,
            ``'mixed'``.
        NotImplementedError: If ``init_device`` is ``'mixed'`` without an FSDP
            config.
    """
    init_context = contextlib.nullcontext()
    if 'init_device' in model_cfg:
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if model_cfg.init_device not in ['meta', 'cpu', 'mixed']:
            raise ValueError(
                f"init_device must be one of ['meta', 'cpu', 'mixed'], "
                f'got {model_cfg.init_device!r}.')
        if fsdp_config is None and model_cfg.init_device == 'meta':
            warnings.warn(
                "Using `cfg.model.init_device='meta'` is only valid when using FSDP! "
                + "Reverting to `cfg.model.init_device='cpu'`.")
            model_cfg.init_device = 'cpu'
        if model_cfg.init_device == 'meta':
            init_context = init_empty_weights()
        if model_cfg.init_device == 'mixed':
            if fsdp_config is None:
                raise NotImplementedError(
                    'Using init_device `mixed` is only supported with FSDP. ' +
                    'Please add a FSDP config.')
            # Mixed initialization always requires sync_module_states.
            if not fsdp_config.get('sync_module_states', False):
                warnings.warn(
                    'Setting `sync_module_states = True` for FSDP. This is required when using mixed initialization.'
                )
                fsdp_config['sync_module_states'] = True
            fsdp_config.setdefault('use_orig_params', False)
            fsdp_config.setdefault('load_monolith_rank0_only', True)

    # Weights already stored in 16 bits need no param-dtype casting.
    master_dtype = model_cfg.get('master_weights_dtype')
    small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16',
                    'amp_bf16')
    # NOTE(review): `if fsdp_config` also skips an empty-dict FSDP config;
    # confirm whether `is not None` was intended.
    if fsdp_config and master_dtype in small_dtypes:
        reduce_dtype = None
        buffer_dtype = None
        mixed_precision = fsdp_config.get('mixed_precision')
        if isinstance(mixed_precision, Mapping):
            reduce_dtype = mixed_precision.get('reduce_dtype')
            buffer_dtype = mixed_precision.get('buffer_dtype')
        fsdp_config['mixed_precision'] = {
            'param_dtype': None,
            'reduce_dtype': reduce_dtype,
            'buffer_dtype': buffer_dtype,
            'keep_low_precision_grads': True,
        }
    return init_context


def log_config(cfg: DictConfig) -> None:
    """Print the config as YAML and push it to active wandb/MLflow runs.

    This function can be called multiple times to update the wandb and MLflow
    config with different variables. A logger is only updated if it is listed
    under ``cfg.loggers`` and has an active run. Importing ``wandb`` / ``mlflow``
    raises ``ImportError`` if the corresponding logger is requested but the
    package is not installed.
    """
    print(om.to_yaml(cfg))
    # `or {}` also covers a `loggers` key that is present but null.
    loggers = cfg.get('loggers') or {}
    if 'wandb' in loggers:
        import wandb
        if wandb.run:
            wandb.config.update(om.to_container(cfg, resolve=True))
    if 'mlflow' in loggers:
        import mlflow
        if mlflow.active_run():
            mlflow.log_params(params=om.to_container(cfg, resolve=True))