import json
import logging
import os
from os.path import dirname, join

import torch.distributed as dist

from utils.config import Config
from utils.distributed import init_distributed_mode, is_main_process
from utils.logger import setup_logger

logger = logging.getLogger(__name__)
def setup_config():
    """Combine the YAML config and command-line overrides with OmegaConf.

    Also converts types, e.g., `'None'` (str) --> `None` (NoneType).
    """
    config = Config.get_config()
    if config.debug:
        config.wandb.enable = False
    return config
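
# Illustrative sketch (assumption: Config.get_config merges the YAML file with
# CLI dotlist overrides via OmegaConf; the exact invocation below is
# hypothetical). A call like
#   python train.py config.yaml debug=True
# would yield a config where `config.debug` is the Python bool True and any
# literal 'None' strings have been converted to None; with debug set,
# setup_config() then forces wandb off.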
def setup_evaluate_config(config):
    """Set up evaluation defaults, e.g., disable wandb and derive output_dir."""
    assert config.evaluate
    config.wandb.enable = False
    if config.output_dir is None:
        config.output_dir = join(dirname(config.pretrained_path), "eval")
    return config
def setup_output_dir(output_dir, excludes=["code"]):
    """Ensure we are not overwriting an existing, non-empty output dir."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=False)
    else:
        existing_dirs_files = os.listdir(output_dir)  # list
        remaining = set(existing_dirs_files) - set(excludes)
        # ignore slurm logs and stdout captures left over from previous runs
        remaining = [e for e in remaining if "slurm" not in e]
        remaining = [e for e in remaining if ".out" not in e]
        # Intentionally warn instead of asserting, so rerunning into a
        # non-empty dir does not crash:
        # assert len(remaining) == 0, f"remaining dirs or files: {remaining}"
        logger.warning(f"remaining dirs or files: {remaining}")
def setup_deepspeed_zero_config(stage):
    """Build the `zero_optimization` section of the DeepSpeed config for the given ZeRO stage."""
    if stage == 1:
        return {"stage": 1, "reduce_bucket_size": 5e8}
    if stage == 2:
        return {
            "stage": 2,
            "contiguous_gradients": False,
            "overlap_comm": False,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
            "offload_optimizer": {"device": "cpu"},
        }
    if stage == 3:
        return {
            "stage": 3,
            "contiguous_gradients": True,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_prefetch_bucket_size": 1e7,
            "stage3_param_persistence_threshold": 1e5,
            "reduce_bucket_size": 1e7,
            "sub_group_size": 1e9,
            "offload_optimizer": {"device": "cpu"},
            "offload_param": {"device": "cpu"},
        }
    raise ValueError(f"Wrong stage for deepspeed: {stage}")
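
# Usage sketch: the returned dict is meant to be embedded under the
# "zero_optimization" key of the DeepSpeed config (see setup_deepspeed_config
# below). For example, assuming stage 2:
#   zero_cfg = setup_deepspeed_zero_config(2)
#   assert zero_cfg["stage"] == 2
#   assert zero_cfg["offload_optimizer"]["device"] == "cpu"
# Stages 2 and 3 both offload optimizer state to CPU; stage 3 additionally
# offloads the parameters themselves.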
def setup_deepspeed_config(config):
    config.deepspeed_config = join(config.output_dir, "deepspeed_config.json")
    opts = config.optimizer
    logger.info(f"Write deepspeed config to {config.deepspeed_config}")
    # Only the main process writes the file; other ranks just record its path.
    if not is_main_process():
        return config
    os.makedirs(config.output_dir, exist_ok=True)
    with open(config.deepspeed_config, mode="w") as writer:
        ds_config = {
            "train_batch_size": config.batch_size * dist.get_world_size(),
            "train_micro_batch_size_per_gpu": config.batch_size,
            "steps_per_print": 100,
            "optimizer": {
                "type": "Adam",
                "adam_w_mode": True,
                "params": {
                    "lr": opts.lr,
                    "weight_decay": opts.weight_decay,
                    "bias_correction": True,
                    "betas": [opts.opt_betas[0], opts.opt_betas[1]],
                    "eps": 1e-8,
                },
            },
        }
        if config.deepspeed.stage != 0:
            ds_config["zero_optimization"] = setup_deepspeed_zero_config(config.deepspeed.stage)
        if config.use_half_precision:
            if config.get("use_bf16", False):
                ds_config["bf16"] = {"enabled": True}
            else:
                ds_config["fp16"] = {
                    "enabled": True,
                    "auto_cast": False,
                    "loss_scale": 0,  # 0 enables dynamic loss scaling
                    "initial_scale_power": 16,
                    "loss_scale_window": 1000,
                    "hysteresis": 2,
                    "consecutive_hysteresis": False,
                    "min_loss_scale": 1,
                }
        else:
            assert config.deepspeed.stage == 0, "You must use fp16 or bf16 when using ZeRO!"
        if opts.get("max_grad_norm", -1) > 0:
            ds_config["gradient_clipping"] = opts.max_grad_norm
        writer.write(json.dumps(ds_config, indent=2))
    return config
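
# For illustration (hypothetical values: batch_size=32 on 8 GPUs, lr=1e-4,
# weight_decay=0.02, opt_betas=[0.9, 0.999], deepspeed.stage=0, bf16 enabled),
# the emitted deepspeed_config.json would look roughly like:
#   {
#     "train_batch_size": 256,
#     "train_micro_batch_size_per_gpu": 32,
#     "steps_per_print": 100,
#     "optimizer": {"type": "Adam", "adam_w_mode": true, "params": {...}},
#     "bf16": {"enabled": true}
#   }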
def setup_main():
    """
    Setup config, logger, output_dir, etc.
    Shared by pretraining and all downstream tasks.
    """
    config = setup_config()
    if hasattr(config, "evaluate") and config.evaluate:
        config = setup_evaluate_config(config)
    init_distributed_mode(config)
    if hasattr(config, "deepspeed") and config.deepspeed.enable:
        config = setup_deepspeed_config(config)
    if is_main_process():
        setup_output_dir(config.output_dir, excludes=["code"])
        setup_logger(output=config.output_dir, color=True, name="vindlu")
        logger.info(f"config: {Config.pretty_text(config)}")
        Config.dump(config, os.path.join(config.output_dir, "config.json"))
    dist.barrier()
    return config
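
# Minimal usage sketch (assumption: a train.py entry point in this repo; the
# module name `setup` below is hypothetical):
#
#   from setup import setup_main
#
#   if __name__ == "__main__":
#       cfg = setup_main()
#       # distributed mode, output_dir, logger and (optionally) the DeepSpeed
#       # JSON are now ready; build dataloaders/model from cfg and train.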