Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import torch | |
import json | |
import warnings | |
import omegaconf | |
from omegaconf import OmegaConf | |
from sat.helpers import print_rank0 | |
from sat import mpu | |
from sat.arguments import set_random_seed | |
from sat.arguments import add_training_args, add_evaluation_args, add_data_args | |
import torch.distributed | |
def add_model_config_args(parser):
    """Register the "model" argument group on ``parser`` and return it.

    Options added:
        --base                 zero or more YAML config paths (merged later by
                               ``process_config_to_args``).
        --model-parallel-size  model-parallel world size (default 1).
        --force-pretrain       flag.
        --device               device index; -1 lets ``get_args`` choose.
        --debug                flag.
        --log-image            explicit boolean value (default True).

    Args:
        parser: An ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, for chaining.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` would turn any non-empty string (even
        # "False") into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    group = parser.add_argument_group("model", "model configuration")
    group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
    group.add_argument(
        "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
    )
    group.add_argument("--force-pretrain", action="store_true")
    group.add_argument("--device", type=int, default=-1)
    group.add_argument("--debug", action="store_true")
    group.add_argument("--log-image", type=_str2bool, default=True)
    return parser
def add_sampling_config_args(parser):
    """Register the "sampling" argument group on ``parser`` and return it.

    Covers output/input locations, input type and file, final size, the
    sdedit / force-inference / image2video flags, grid rows, LCM steps,
    number of sampled frames and fps, latent channels, and two explicit
    boolean options (``--only-save-latents`` and ``--only-log-video-latents``)
    that take a value such as ``True``/``False``.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, for chaining.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` would turn any non-empty string (even
        # "False") into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    group = parser.add_argument_group("sampling", "Sampling Configurations")
    group.add_argument("--output-dir", type=str, default="samples")
    group.add_argument("--input-dir", type=str, default=None)
    group.add_argument("--input-type", type=str, default="cli")
    group.add_argument("--input-file", type=str, default="input.txt")
    group.add_argument("--final-size", type=int, default=2048)
    group.add_argument("--sdedit", action="store_true")
    group.add_argument("--grid-num-rows", type=int, default=1)
    group.add_argument("--force-inference", action="store_true")
    group.add_argument("--lcm_steps", type=int, default=None)
    group.add_argument("--sampling-num-frames", type=int, default=32)
    group.add_argument("--sampling-fps", type=int, default=8)
    group.add_argument("--only-save-latents", type=_str2bool, default=False)
    group.add_argument("--only-log-video-latents", type=_str2bool, default=False)
    group.add_argument("--latent-channels", type=int, default=32)
    group.add_argument("--image2video", action="store_true")
    return parser
def get_args(args_list=None, parser=None):
    """Parse, post-process and validate all the args.

    Builds (or extends) an ``argparse`` parser with the model, sampling,
    training, evaluation, data and DeepSpeed option groups, and parses
    ``args_list``. ``None`` means argparse reads ``sys.argv``. It then:

    * merges the ``--base`` YAML configs into the namespace;
    * fills ``rank`` / ``world_size`` / ``local_rank`` from the environment;
    * picks ``device`` when it was left at -1;
    * reconciles fp16/bf16, batch size, gradient accumulation and optimizer
      settings with the DeepSpeed config;
    * initializes torch.distributed and seeds the RNGs.

    Args:
        args_list: Optional list of argument strings.
        parser: Optional pre-built ``argparse.ArgumentParser`` to extend.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If ``local_rank`` and ``device`` disagree outside
            inference mode.
    """
    if parser is None:
        parser = argparse.ArgumentParser(description="sat")
    else:
        assert isinstance(parser, argparse.ArgumentParser)
    parser = add_model_config_args(parser)
    parser = add_sampling_config_args(parser)
    parser = add_training_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_data_args(parser)

    import deepspeed

    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args(args_list)
    # Values from the --base YAML configs overwrite the parsed CLI values.
    args = process_config_to_args(args)

    if not args.train_data:
        print_rank0("No training data specified", level="WARNING")

    assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000  # default 10k iters
        print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")

    args.cuda = torch.cuda.is_available()

    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    if args.local_rank is None:
        args.local_rank = int(os.getenv("LOCAL_RANK", "0"))  # torchrun

    if args.device == -1:
        # local_rank is always set at this point, so it selects the GPU
        # directly (the former `rank % device_count()` fallback was dead code).
        if torch.cuda.device_count() == 0:
            args.device = "cpu"
        else:
            args.device = args.local_rank

    if args.local_rank != args.device and args.mode != "inference":
        raise ValueError(
            "LOCAL_RANK (default 0) and args.device inconsistent. "
            "This can only happens in inference mode. "
            "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
        )

    if args.rank == 0:
        print_rank0("using world size: {}".format(args.world_size))

    if args.train_data_weights is not None:
        assert args.train_data is not None and len(args.train_data_weights) == len(
            args.train_data
        ), "train_data_weights must have one entry per train_data entry."

    # True  -> the DeepSpeed config is the bundled template; write args into it.
    # False -> the DeepSpeed config was supplied by the user; read args from it.
    # Bound up front so the `args.deepspeed` branch below never reads an
    # unbound name when DeepSpeed is enabled outside training mode.
    override_deepspeed_config = False
    if args.mode != "inference":  # training with deepspeed
        args.deepspeed = True
        if args.deepspeed_config is None:  # not specified
            deepspeed_config_path = os.path.join(
                os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
            )
            with open(deepspeed_config_path, encoding="utf-8") as file:
                args.deepspeed_config = json.load(file)
            override_deepspeed_config = True

    assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."

    if args.zero_stage > 0 and not args.fp16 and not args.bf16:
        print_rank0("Automatically set fp16=True to use ZeRO.")
        args.fp16 = True
        args.bf16 = False

    if args.deepspeed:
        args.deepspeed_activation_checkpointing = bool(args.checkpoint_activations)

        deepspeed_config = args.deepspeed_config
        if deepspeed_config is not None:
            if override_deepspeed_config:  # not specify deepspeed_config, use args
                if args.fp16:
                    deepspeed_config["fp16"]["enabled"] = True
                elif args.bf16:
                    deepspeed_config["bf16"]["enabled"] = True
                    deepspeed_config["fp16"]["enabled"] = False
                else:
                    deepspeed_config["fp16"]["enabled"] = False
                deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
                deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
                optimizer_params_config = deepspeed_config["optimizer"]["params"]
                optimizer_params_config["lr"] = args.lr
                optimizer_params_config["weight_decay"] = args.weight_decay
            else:  # override args with values in deepspeed_config
                if args.rank == 0:
                    print_rank0("Will override arguments with manually specified deepspeed_config!")
                args.fp16 = bool("fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"])
                args.bf16 = bool("bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"])
                if "train_micro_batch_size_per_gpu" in deepspeed_config:
                    args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
                # Missing key -> None, matching the original explicit else-branch.
                args.gradient_accumulation_steps = deepspeed_config.get("gradient_accumulation_steps")
                if "optimizer" in deepspeed_config:
                    optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
                    args.lr = optimizer_params_config.get("lr", args.lr)
                    args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
            args.deepspeed_config = deepspeed_config

    # initialize distributed and random seed because it always seems to be necessary.
    initialize_distributed(args)
    # Offset the seed per data-parallel rank.
    args.seed = args.seed + mpu.get_data_parallel_rank()
    set_random_seed(args.seed)
    return args
def initialize_distributed(args):
    """Initialize torch.distributed, the model-parallel groups and DeepSpeed.

    If torch.distributed is already initialized, this only checks or creates
    the model-parallel groups. Otherwise it:

    * selects the CUDA device (unless ``args.device == "cpu"``);
    * opens a TCP process group at ``MASTER_ADDR:MASTER_PORT``, defaulting to
      ``localhost`` and a free port for single-process runs or ``6000``
      otherwise;
    * builds the model-parallel groups and the VAE context-parallel group;
    * initializes DeepSpeed when ``args.deepspeed`` is set, or otherwise
      registers the DeepSpeed model-parallel RNG tracker on a best-effort
      basis.

    Side effects: sets ``args.master_ip`` and ``args.master_port`` on the
    fresh-initialization path.

    Returns:
        ``False`` when torch.distributed and the model-parallel groups were
        both already initialized (nothing is done); ``True`` otherwise.

    Raises:
        ValueError: If the existing model-parallel world size differs from
            ``args.model_parallel_size``.
    """
    if torch.distributed.is_initialized():
        if mpu.model_parallel_is_initialized():
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError(
                    "model_parallel_size is inconsistent with prior configuration. "
                    "We currently do not support changing model_parallel_size."
                )
            return False
        if args.model_parallel_size > 1:
            warnings.warn(
                "model_parallel_size > 1 but torch.distributed is not initialized via SAT. "
                "Please carefully make sure the correctness on your own."
            )
        mpu.initialize_model_parallel(args.model_parallel_size)
        return True

    # args.device was already resolved by get_args (-1 -> local_rank or "cpu").
    if args.device != "cpu":
        torch.cuda.set_device(args.device)

    # Call the init process
    args.master_ip = os.getenv("MASTER_ADDR", "localhost")
    if args.world_size == 1:
        from sat.helpers import get_free_port

        default_master_port = str(get_free_port())
    else:
        default_master_port = "6000"
    args.master_port = os.getenv("MASTER_PORT", default_master_port)
    init_method = "tcp://" + args.master_ip + ":" + args.master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
    )

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Set vae context parallel group equal to model parallel group
    from sgm.util import set_context_parallel_group, initialize_context_parallel

    if args.model_parallel_size <= 2:
        set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
    else:
        initialize_context_parallel(2)

    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed:
        import deepspeed

        deepspeed.init_distributed(
            dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
        )
        # NOTE: deepspeed.checkpointing.configure(...) is intentionally not called here.
    else:
        # In model-only mode DeepSpeed is not initialized, but the model-parallel
        # RNG tracker is still registered because the seed is saved when using
        # dropout. Failure (e.g. DeepSpeed missing) is tolerated and logged at DEBUG.
        try:
            from deepspeed.runtime.activation_checkpointing.checkpointing import (
                _CUDA_RNG_STATE_TRACKER,
                _MODEL_PARALLEL_RNG_TRACKER_NAME,
            )

            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1)  # default seed 1
        except Exception as e:
            print_rank0(str(e), level="DEBUG")
    return True
def process_config_to_args(args):
    """Fetch args from only --base: merge the listed YAML configs into ``args``.

    Every path in ``args.base`` is loaded with OmegaConf and merged in order,
    so later files override earlier ones. Then:

    * each key of the top-level ``args`` section overwrites the attribute of
      the same name on ``args``. Keys with no matching attribute are ignored.
      ``DictConfig``/``ListConfig`` values are converted to plain Python
      containers first.
    * the ``model`` section, if present, is stored as ``args.model_config``
      (left as an OmegaConf object).
    * the ``deepspeed`` section, if present, is converted to plain Python
      objects and stored as ``args.deepspeed_config``.
    * the ``data`` section, if present, is stored as ``args.data_config``
      (left as an OmegaConf object).

    When ``--base`` was not given (``args.base`` is ``None`` or empty), there
    is nothing to merge and ``args`` is returned unchanged instead of raising.

    Args:
        args: The parsed ``argparse.Namespace``; it is modified in place.

    Returns:
        The same ``args`` object.
    """
    config_paths = args.base or []
    if not config_paths:
        return args

    config = OmegaConf.merge(*[OmegaConf.load(path) for path in config_paths])

    args_config = config.pop("args", OmegaConf.create())
    for key in args_config:
        value = args_config[key]
        if isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)):
            value = OmegaConf.to_object(value)
        if hasattr(args, key):
            setattr(args, key, value)

    if "model" in config:
        args.model_config = config.pop("model", OmegaConf.create())
    if "deepspeed" in config:
        args.deepspeed_config = OmegaConf.to_object(config.pop("deepspeed", OmegaConf.create()))
    if "data" in config:
        args.data_config = config.pop("data", OmegaConf.create())
    return args