# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Train a GAN using the techniques described in the paper
"Alias-Free Generative Adversarial Networks"."""

import os
import click
import re
import json
import tempfile
import torch

import dnnlib
from training import training_loop
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops
from torch_utils import gen_utils

# ----------------------------------------------------------------------------

def subprocess_fn(rank, c, temp_dir):
    """Per-process entry point: set up logging/distributed state and run the training loop.

    Args:
        rank: Index of this process; also used as the CUDA device index when
            more than one GPU is requested.
        c: Training config; `c.run_dir` and `c.num_gpus` are read here and the
            whole mapping is forwarded to `training_loop.training_loop` as kwargs.
        temp_dir: Directory that holds the file used for the torch.distributed
            file-based rendezvous.
    """
    dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    if c.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            # Windows: gloo backend, and the file URL needs forward slashes.
            backend = 'gloo'
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{init_file}'
        torch.distributed.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=c.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if c.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute training loop.
    training_loop.training_loop(rank=rank, **c)

# ----------------------------------------------------------------------------

def launch_training(c, desc, outdir, dry_run):
    """Pick the run directory, print the options and start the training process(es).

    Args:
        c: Training config; `c.run_dir` is filled in here.
        desc: Description string passed to `gen_utils.make_run_dir`.
        outdir: Parent directory for the run, passed to `gen_utils.make_run_dir`.
        dry_run: If true, print the options and return without writing
            `training_options.json` or launching anything.
    """
    dnnlib.util.Logger(should_flush=True)

    # Pick output directory.
    c.run_dir = gen_utils.make_run_dir(outdir=outdir, desc=desc, dry_run=dry_run)

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(c, indent=2))
    print()
    print(f'Output directory: {c.run_dir}')
    print(f'Number of GPUs: {c.num_gpus}')
    print(f'Batch size: {c.batch_size} images')
    print(f'Training duration: {c.total_kimg} kimg')
    print(f'Dataset path: {c.training_set_kwargs.path}')
    print(f'Dataset size: {c.training_set_kwargs.max_size} images')
    print(f'Dataset resolution: {c.training_set_kwargs.resolution}')
    print(f'Dataset labels: {c.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips: {c.training_set_kwargs.xflip}')
    print(f'Dataset y-flips: {c.training_set_kwargs.yflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Save the options next to the results.
    with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(c, f, indent=2)

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if c.num_gpus == 1:
            # Single GPU: run in-process instead of spawning.
            subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)

# ----------------------------------------------------------------------------

def init_dataset_kwargs(data):
    """Build the dataset kwargs for `data` and probe the dataset once to fill in its properties.

    Args:
        data: Path to the training data.

    Returns:
        Tuple `(dataset_kwargs, dataset_name)` where `dataset_kwargs` has
        `resolution`, `use_labels` and `max_size` set from the constructed dataset.

    Raises:
        click.ClickException: If constructing or inspecting the dataset raises `IOError`.
    """
    try:
        dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False, yflip=False)
        dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)  # Subclass of training.dataset.Dataset.
        dataset_kwargs.resolution = dataset_obj.resolution  # Be explicit about resolution.
        dataset_kwargs.use_labels = dataset_obj.has_labels  # Be explicit about labels.
        dataset_kwargs.max_size = len(dataset_obj)  # Be explicit about dataset size.
        return dataset_kwargs, dataset_obj.name
    except IOError as err:
        # Keep the original exception as the cause so the traceback stays useful.
        raise click.ClickException(f'--data: {err}') from err

# ----------------------------------------------------------------------------

def parse_comma_separated_list(s):
    """Turn a comma-separated string into a list of items.

    A list is returned unchanged. `None`, the empty string and 'none' (any case)
    give an empty list. Whitespace around items is stripped and empty items are
    dropped, so 'a, b' and 'a,,b' both give ['a', 'b'].
    """
    if isinstance(s, list):
        return s
    if s is None:
        return []
    s = s.strip()
    if s == '' or s.lower() == 'none':
        return []
    return [item.strip() for item in s.split(',') if item.strip()]

# ----------------------------------------------------------------------------

@click.command()

# Required.
# NOTE(review): 'stylegan2-cifar' is handled in main() but is not offered here, so that branch is unreachable.
@click.option('--cfg',          help='Base configuration',                                  type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2', 'stylegan2-ext']), required=True)
# NOTE(review): metavar advertises DIR, but dir_okay=False rejects directories - confirm which is intended.
@click.option('--data',         help='Training data', metavar='[ZIP|DIR]',                  type=click.Path(exists=True, dir_okay=False), required=True)
@click.option('--gpus',         help='Number of GPUs to use', metavar='INT',                type=click.IntRange(min=1), required=True)
@click.option('--batch',        help='Total batch size', metavar='INT',                     type=click.IntRange(min=1), required=True)

# Optional features.
@click.option('--cond',         help='Train conditional model', metavar='BOOL',             type=bool, default=False, show_default=True)
@click.option('--mirror',       help='Enable dataset x-flips', metavar='BOOL',              type=bool, default=False, show_default=True)
@click.option('--mirror-y',     help='Enable dataset y-flips', metavar='BOOL',              type=bool, default=False, show_default=True)
@click.option('--aug',          help='Augmentation mode',                                   type=click.Choice(['noaug', 'ada', 'fixed']), default='ada', show_default=True)
@click.option('--augpipe',      help='Augmentation pipeline',                               type=click.Choice(['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc']), default='bgc', show_default=True)
@click.option('--resume',       help='Resume from given network pickle', metavar='[PATH|URL]', type=str)
@click.option('--initstrength', help='Override ADA augment strength at the beginning',      type=click.FloatRange(min=0.0), )
@click.option('--freezeD',      help='Freeze first layers of D', metavar='INT',             type=click.IntRange(min=0), default=0, show_default=True)

# Experimental features.
@click.option('--freezeM',      help='Freeze first layers of the Mapping Network Gm',       type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--freezeE',      help='Freeze the embedding layer for conditional models',   type=bool, default=False, show_default=True)
@click.option('--blur-percent', help='Blur all images for the first % of training (before D)', type=click.FloatRange(min=0.0, max=100.0), default=0.0, show_default=True)

# Misc hyperparameters.
@click.option('--gamma',        help='R1 regularization weight', metavar='FLOAT',           type=click.FloatRange(min=0), default=None)
@click.option('--p',            help='Probability for --aug=fixed', metavar='FLOAT',        type=click.FloatRange(min=0, max=1), default=0.2, show_default=True)
@click.option('--target',       help='Target value for --aug=ada', metavar='FLOAT',         type=click.FloatRange(min=0, max=1), default=0.6, show_default=True)
@click.option('--batch-gpu',    help='Limit batch size per GPU', metavar='INT',             type=click.IntRange(min=1))
@click.option('--cbase',        help='Capacity multiplier', metavar='INT',                  type=click.IntRange(min=1), default=32768, show_default=True)
@click.option('--cmax',         help='Max. feature maps', metavar='INT',                    type=click.IntRange(min=1), default=512, show_default=True)
@click.option('--glr',          help='G learning rate [default: varies]', metavar='FLOAT',  type=click.FloatRange(min=0))
@click.option('--dlr',          help='D learning rate', metavar='FLOAT',                    type=click.FloatRange(min=0), default=0.002, show_default=True)
@click.option('--map-depth',    help='Mapping network depth [default: varies]', metavar='INT', type=click.IntRange(min=1))
@click.option('--mbstd-group',  help='Minibatch std group size', metavar='INT',             type=click.IntRange(min=1), default=4, show_default=True)

# Misc settings.
@click.option('--outdir',       help='Where to save the results', metavar='DIR',            type=click.Path(file_okay=False), default=os.path.join(os.getcwd(), 'training-runs'))
@click.option('--desc',         help='String to include in result dir name', metavar='STR', type=str)
@click.option('--metrics',      help='Quality metrics', metavar='[NAME|A,B,C|none]',        type=parse_comma_separated_list, default='none', show_default=True)
@click.option('--kimg',         help='Total training duration', metavar='KIMG',             type=click.IntRange(min=1), default=25000, show_default=True)
@click.option('--resume-kimg',  help='Number of kimg images to resume from', metavar='RKIMG', type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--tick',         help='How often to print progress', metavar='KIMG',         type=click.IntRange(min=1), default=4, show_default=True)
@click.option('--snap',         help='How often to save model snapshots', metavar='TICKS',  type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--img-snap',     help='How often to save image snapshots', metavar='INT',    type=click.IntRange(min=1), default=50, show_default=True)
@click.option('--snap-res',     help='Screen resolution to save image snapshot',            type=click.Choice(['1080p', '4k', '8k']), default='4k', show_default=True)
@click.option('--seed',         help='Random seed', metavar='INT',                          type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--fp32',         help='Disable mixed-precision', metavar='BOOL',             type=bool, default=False, show_default=True)
@click.option('--nobench',      help='Disable cuDNN benchmarking', metavar='BOOL',          type=bool, default=False, show_default=True)
@click.option('--workers',      help='DataLoader worker processes', metavar='INT',          type=click.IntRange(min=1), default=3, show_default=True)
@click.option('-n', '--dry-run', help='Print training options and exit',                    is_flag=True)
def main(**kwargs):
    """Train a GAN using the techniques described in the paper
    "Alias-Free Generative Adversarial Networks".

    Examples:

    \b
    # Train StyleGAN3-T for AFHQv2 using 8 GPUs.
    python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \\
        --gpus=8 --batch=32 --gamma=8.2 --mirror=1

    \b
    # Fine-tune StyleGAN3-R for MetFaces-U using 8 GPUs, starting from the pre-trained FFHQ-U pickle.
    python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \\
        --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \\
        --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl

    \b
    # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs.
    python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \\
        --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug
    """
    # The docstring above is shown verbatim by `--help`, so developer notes live in comments.
    # Click lower-cases option names: --freezeD/--freezeM/--freezeE arrive as
    # opts.freezed / opts.freezem / opts.freezee.

    # Initialize config.
    opts = dnnlib.EasyDict(kwargs)  # Command line arguments.
    c = dnnlib.EasyDict()  # Main config dict.
    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict())
    c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0, 0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0, 0.99], eps=1e-8)  # TODO: Use ComplexSGD: https://arxiv.org/abs/2102.08431
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss')
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data)
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise click.ClickException('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = opts.mirror
    c.training_set_kwargs.yflip = opts.mirror_y

    # Hyperparameters & settings.
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
    c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax
    c.G_kwargs.mapping_kwargs.num_layers = (8 if opts.cfg == 'stylegan2' else 2) if opts.map_depth is None else opts.map_depth
    c.G_kwargs.mapping_kwargs.freeze_layers = opts.freezem
    c.G_kwargs.mapping_kwargs.freeze_embed = opts.freezee
    c.D_kwargs.block_kwargs.freeze_layers = opts.freezed
    c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group
    if opts.gamma is not None:
        c.loss_kwargs.r1_gamma = float(opts.gamma)
    else:
        # Use heuristic from StyleGAN2-ADA
        c.loss_kwargs.r1_gamma = 0.0002 * c.training_set_kwargs.resolution ** 2 / c.batch_size
        print(f'Using heuristic, R1 gamma set at: {c.loss_kwargs.r1_gamma:.4f}')
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.resume_kimg = opts.resume_kimg
    c.kimg_per_tick = opts.tick
    c.network_snapshot_ticks = opts.snap
    c.image_snapshot_ticks = opts.img_snap
    c.snap_res = opts.snap_res
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    # The first check also guarantees c.batch_gpu > 0 when it was derived as batch // gpus.
    if c.batch_size % c.num_gpus != 0:
        raise click.ClickException('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
    if c.batch_gpu < c.D_kwargs.epilogue_kwargs.mbstd_group_size:
        raise click.ClickException('--batch-gpu cannot be smaller than --mbstd-group')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    c.ema_kimg = c.batch_size * 10 / 32
    if opts.cfg == 'stylegan2':
        c.G_kwargs.class_name = 'training.networks_stylegan2.Generator'
        c.loss_kwargs.style_mixing_prob = 0.9  # Enable style mixing regularization.
        c.loss_kwargs.pl_weight = 2  # Enable path length regularization.
        c.G_reg_interval = 4  # Enable lazy regularization for G.
        c.G_kwargs.fused_modconv_default = 'inference_only'  # Speed up training by using regular convolutions instead of grouped convolutions.
        c.loss_kwargs.pl_no_weight_grad = True  # Speed up path length regularization by skipping gradient computation wrt. conv2d weights.
    elif opts.cfg == 'stylegan2-cifar':
        # CIFAR's config
        # NOTE(review): unreachable for now; '--cfg' does not list 'stylegan2-cifar'.
        c.G_kwargs.class_name = 'training.networks_stylegan2.Generator'
        c.loss_kwargs.pl_weight = 0  # Disable path length regularization (default)
        c.loss_kwargs.style_mixing_prob = 0  # Disable style mixing regularization (default)
        c.D_kwargs.architecture = 'orig'  # Disable residual skip connections in D
        c.G_kwargs.mapping_kwargs.num_layers = 2  # Lower the number of layers in the mapping network
        c.ema_kimg = 500
    elif opts.cfg == 'stylegan2-ext':
        # Aydao's config (to be tested later)
        c.G_kwargs.class_name = 'training.networks_stylegan2.Generator'
        c.loss_kwargs.pl_weight = 0  # Disable path length regularization (default)
        c.loss_kwargs.style_mixing_prob = 0  # Disable style mixing regularization (default)
        # Double Generator capacity
        c.G_kwargs.extended_sgan2 = True
        c.G_kwargs.channel_base = 32 << 10  # (default already)
        # Double the number of feature maps
        c.G_kwargs.channel_max = 1024
        c.D_kwargs.epilogue_kwargs.mbstd_num_channels = 4
        # Mapping layer
        c.G_kwargs.z_dim = 1024
        c.G_kwargs.w_dim = 1024
        c.G_kwargs.mapping_kwargs.num_layers = 4  # TODO: test with a higher number later on
        c.G_kwargs.mapping_kwargs.layer_features = 1024  # TODO: test with a wider mapping network later on
        # TODO: Enable top-k training
        # TODO: try different values of c.ema_kimg
        # TODO: Reduce in-memory size (lower batch size, more layers with fp16, etc)
    else:
        c.G_kwargs.class_name = 'training.networks_stylegan3.Generator'
        c.G_kwargs.magnitude_ema_beta = 0.5 ** (c.batch_size / (20 * 1e3))
        if opts.cfg == 'stylegan3-r':
            c.G_kwargs.conv_kernel = 1  # Use 1x1 convolutions.
            c.G_kwargs.channel_base *= 2  # Double the number of feature maps.
            c.G_kwargs.channel_max *= 2
            c.G_kwargs.use_radial_filters = True  # Use radially symmetric downsampling filters.
            c.loss_kwargs.blur_init_sigma = 10  # Blur the images seen by the discriminator.
            c.loss_kwargs.blur_fade_kimg = c.batch_size * 200 / 32  # Fade out the blur during the first N kimg.

    # Curriculum learning: start blurring all images before passing to the Discriminator, and fade it out
    if opts.blur_percent > 0:
        c.loss_kwargs.blur_init_sigma = 10
        total_kimg = opts.kimg - opts.resume_kimg
        c.loss_kwargs.blur_fade_kimg = (opts.blur_percent / 100.0) * total_kimg

    # Augmentation.
    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1)
    }
    if opts.aug != 'noaug':
        c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[opts.augpipe])
        if opts.aug == 'ada':
            c.ada_target = opts.target
        if opts.aug == 'fixed':
            c.augment_p = opts.p

    # Initial augmentation strength for opts.aug == 'ada'; thanks to @dvschultz
    if opts.initstrength is not None:
        c.augment_p = opts.initstrength

    # Resume.
    if opts.resume is None:
        resume_desc = 'no_resume'
    else:
        if opts.resume in gen_utils.resume_specs[opts.cfg]:
            # A named pre-trained model for this config.
            c.resume_pkl = gen_utils.resume_specs[opts.cfg][opts.resume]
            resume_desc = f'resume_{opts.resume}'
        else:
            # A local file
            c.resume_pkl = opts.resume
            resume_desc = 'resume_custom'
        c.ada_kimg = 100  # Make ADA react faster at the beginning.
        c.ema_rampup = None  # Disable EMA rampup.
        # NOTE(review): this also cancels the sigma set by --blur-percent above - confirm that is intended.
        c.loss_kwargs.blur_init_sigma = 0  # Disable blur rampup.

    # Performance-related toggles.
    if opts.fp32:
        c.G_kwargs.num_fp16_res = c.D_kwargs.num_fp16_res = 0
        c.G_kwargs.conv_clamp = c.D_kwargs.conv_clamp = None
    if opts.nobench:
        c.cudnn_benchmark = False

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}-{resume_desc}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Launch.
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)

# ----------------------------------------------------------------------------

if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter

# ----------------------------------------------------------------------------