import torch.nn.init as init
import argparse
from cgitb import text  # NOTE(review): appears unused in this section; `cgitb` was removed in Python 3.13 (PEP 594) — confirm and drop.
import copy
import gc
import itertools
import logging
import math
import os
import random
import shutil
from tkinter import NO  # NOTE(review): looks like an accidental IDE auto-import; it requires Tk on the training host — confirm and drop.
import warnings
from contextlib import nullcontext
from pathlib import Path

import PIL.Image
import PIL.ImageOps
import numpy as np
from sympy import N  # NOTE(review): looks like an accidental IDE auto-import — confirm and drop.
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast

# from transformer_sd3 import SD3Transformer2DModel
import diffusers
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
    SD3Transformer2DModel,
    StableDiffusion3InstructPix2PixPipeline
)
from diffusers.optimization import get_scheduler
from diffusers.utils import (
    check_min_version,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
import accelerate
import datasets
import PIL
import requests
import torch.nn as nn
import torch.nn.functional as F
from os.path import join
from datasets import load_dataset
from packaging import version


def load_text_encoders(class_one, class_two, class_three):
    """Instantiate the three SD3 text encoders from the pretrained checkpoint.

    Reads ``pretrained_model_name_or_path``, ``revision`` and ``variant`` from the
    module-level ``args`` namespace, so ``args`` must be set before calling.

    Args:
        class_one: Model class for the ``text_encoder`` subfolder.
        class_two: Model class for the ``text_encoder_2`` subfolder.
        class_three: Model class for the ``text_encoder_3`` subfolder.

    Returns:
        Tuple ``(text_encoder_one, text_encoder_two, text_encoder_three)``.
    """
    text_encoder_one = class_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    text_encoder_two = class_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
    )
    text_encoder_three = class_three.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant
    )
    return text_encoder_one, text_encoder_two, text_encoder_three


def tokenize_prompt(tokenizer, prompt, max_sequence_length=77):
    """Tokenize ``prompt`` to fixed-length ids.

    Sequences are padded to and truncated at ``max_sequence_length``.

    Returns:
        The ``input_ids`` tensor produced by ``tokenizer`` (PyTorch tensors).
    """
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length,
    text_encoder_dtype,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    """Encode a prompt with the T5 encoder.

    Args:
        text_encoder: The T5 encoder; its first output is used as the embeddings.
        tokenizer: Tokenizer used only when ``text_input_ids`` is None.
        max_sequence_length: Pad/truncate length for the tokenizer.
        text_encoder_dtype: dtype the embeddings are cast to.
        prompt: A string or list of strings. Required when ``text_input_ids`` is None.
        num_images_per_prompt: Number of times each embedding is repeated along the batch axis.
        device: Device the ids are moved to and the embeddings are cast to.
        text_input_ids: Optional pre-tokenized ids; when given, ``tokenizer`` is not called.

    Returns:
        Embeddings of shape ``(batch * num_images_per_prompt, seq_len, dim)``.

    Raises:
        ValueError: If neither ``prompt`` nor ``text_input_ids`` is provided.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if text_input_ids is None:
        if prompt is None:
            raise ValueError("Either `prompt` or `text_input_ids` must be provided.")
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)

    # Take the batch size from the encoded ids so pre-tokenized input works without a prompt.
    batch_size, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    text_encoder_dtype,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids=None,
):
    """Encode a prompt with one CLIP text encoder.

    Args:
        text_encoder: A CLIP text encoder called with ``output_hidden_states=True``.
        tokenizer: Tokenizer used only when ``text_input_ids`` is None (fixed length 77).
        prompt: A string or list of strings. Required when ``text_input_ids`` is None.
        text_encoder_dtype: dtype the per-token embeddings are cast to.
        device: Device the ids are moved to and the embeddings are cast to.
        num_images_per_prompt: Number of times each embedding is repeated along the batch axis.
        text_input_ids: Optional pre-tokenized ids; when given, ``tokenizer`` is not called.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``. ``prompt_embeds`` is the penultimate hidden
        state, shaped ``(batch * num_images_per_prompt, seq_len, dim)``. ``pooled_prompt_embeds``
        is the encoder's first output (the projected text embeds for
        ``CLIPTextModelWithProjection``), repeated the same way along the batch axis.

    Raises:
        ValueError: If neither ``prompt`` nor ``text_input_ids`` is provided.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if text_input_ids is None:
        if prompt is None:
            raise ValueError("Either `prompt` or `text_input_ids` must be provided.")
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

    encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    pooled_prompt_embeds = encoder_output[0]
    prompt_embeds = encoder_output.hidden_states[-2]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder_dtype, device=device)

    batch_size, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    # The pooled embeddings must be duplicated too, otherwise their batch size no longer
    # matches `prompt_embeds` when num_images_per_prompt > 1.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length=None,
    text_encoders_dtypes=None,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    """Encode ``prompt`` with the two CLIP encoders and the T5 encoder (SD3 layout).

    Args:
        text_encoders: ``[clip_one, clip_two, t5]``. The first two are used as CLIP encoders
            and the last one as T5.
        tokenizers: Tokenizers in the same order as ``text_encoders``.
        prompt: A prompt string or a list of prompts.
        max_sequence_length: Tokenizer length for the T5 encoder.
        text_encoders_dtypes: dtypes the outputs of each encoder are cast to, in encoder
            order. Defaults to ``torch.float32`` for all three.
        device: Target device. Defaults to each encoder's own ``device``.
        num_images_per_prompt: Number of times each prompt's embeddings are repeated
            along the batch axis.
        text_input_ids_list: Optional pre-tokenized ids in encoder order. Non-None entries
            are used instead of calling the matching tokenizer.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``.
        ``prompt_embeds``: the two CLIP hidden states are concatenated on the last axis,
        zero-padded to the T5 width, then concatenated with the T5 hidden states along the
        sequence axis.
        ``pooled_prompt_embeds``: the two CLIP pooled outputs concatenated on the last axis.
    """
    if text_encoders_dtypes is None:
        text_encoders_dtypes = [torch.float32, torch.float32, torch.float32]
    prompt = [prompt] if isinstance(prompt, str) else prompt

    clip_tokenizers = tokenizers[:2]
    clip_text_encoders = text_encoders[:2]
    clip_text_encoders_dtypes = text_encoders_dtypes[:2]
    if text_input_ids_list is not None:
        clip_text_input_ids_list = text_input_ids_list[:2]
    else:
        clip_text_input_ids_list = [None, None]

    clip_prompt_embeds_list = []
    clip_pooled_prompt_embeds_list = []
    zipped_text_encoders = zip(clip_tokenizers, clip_text_encoders, clip_text_encoders_dtypes, clip_text_input_ids_list)
    for tokenizer, text_encoder, clip_text_encoder_dtype, text_input_ids in zipped_text_encoders:
        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            text_encoder_dtype=clip_text_encoder_dtype,
            device=device if device is not None else text_encoder.device,
            num_images_per_prompt=num_images_per_prompt,
            text_input_ids=text_input_ids,
        )
        clip_prompt_embeds_list.append(prompt_embeds)
        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)

    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)

    t5_text_input_ids = text_input_ids_list[-1] if text_input_ids_list is not None else None
    t5_prompt_embed = _encode_prompt_with_t5(
        text_encoders[-1],
        tokenizers[-1],
        max_sequence_length,
        # Use the T5 encoder's own dtype (previously the second CLIP encoder's dtype was used here).
        text_encoders_dtypes[-1],
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[-1].device,
        text_input_ids=t5_text_input_ids,
    )

    # Zero-pad the CLIP features on the last axis so they can be stacked with T5 along the sequence axis.
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

    return prompt_embeds, pooled_prompt_embeds


logger = get_logger(__name__, log_level="INFO")

# Column order: (original image, edited image, edit prompt).
# NOTE(review): main() reads index 1 as the edit-prompt column and index 2 as the
# edited-image column when those CLI column options are None; that looks swapped relative
# to this order — confirm.
DATASET_NAME_MAPPING = {
    "BleachNick/UltraEdit_500k": ("source_image", "edited_image", "edit_prompt"),
}
WANDB_TABLE_COL_NAMES = ["source_image", "edited_image", "edit_prompt"]


def _str_to_bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string, including ``"False"``,
    into ``True``. This converter accepts the usual spellings instead.

    Args:
        value: The raw command-line string. A ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (true/false), got {value!r}.")


def parse_args(input_args=None):
    """Build the command-line parser and parse the training arguments.

    Args:
        input_args: Optional list of argument strings to parse instead of ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``local_rank`` is replaced by the ``LOCAL_RANK``
        environment variable when that variable is set and differs.

    Raises:
        ValueError: If neither ``--dataset_name`` nor ``--train_data_jsonl`` is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--ori_model_name_or_path",
        type=str,
        default=None,
        help="Path to ori_model_name_or_path.",
    )
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="logit_normal",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="AdamW",
        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
    )
    parser.add_argument(
        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
        "uses the value of square root of beta2. Ignored if optimizer is adamW",
    )
    parser.add_argument(
        "--prodigy_use_bias_correction",
        type=_str_to_bool,
        default=True,
        help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
    )
    parser.add_argument(
        "--prodigy_decouple", type=_str_to_bool, default=True, help="Use AdamW style decoupled weight decay"
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        type=_str_to_bool,
        default=True,
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
        "Ignored if optimizer is adamW",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_jsonl",
        type=str,
        default=None,
        help=(
            "Path to a JSON Lines file with the training data, loaded with"
            " `datasets.load_dataset('json', data_files=...)`. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--original_image_column",
        type=str,
        default="source_image",
        help="The column of the dataset containing the original image on which edits were made.",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--edited_image_column",
        type=str,
        default="edited_image",
        help="The column of the dataset containing the edited image.",
    )
    parser.add_argument(
        "--edit_prompt_column",
        type=str,
        default="edit_prompt",
        help="The column of the dataset containing the edit instruction.",
    )
    parser.add_argument(
        "--val_image_url",
        type=str,
        default=None,
        help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        "--val_mask_url",
        type=str,
        default=None,
        help="URL to the mask image that you would like to edit (used during inference for debugging purposes).",
    )
    parser.add_argument(
        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--validation_step",
        type=int,
        default=5000,
        help=(
            "Run fine-tuning validation every X steps. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--top_training_data_sample",
        type=int,
        default=None,
        help="Number of top samples to use for training, ranked by clip-sim-dit. If None, use the full dataset.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd3_edit",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--eval_resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=77,
        help="Maximum sequence length to use with the T5 text encoder",
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--conditioning_dropout_prob",
        type=float,
        default=None,
        help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--text_encoder_lr",
        type=float,
        default=5e-6,
        help="Text encoder learning rate to use.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to also train the three text encoders.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--do_mask",
        action="store_true",
        help="Whether to train the mask-conditioned variant (changes the number of transformer input channels).",
    )
    parser.add_argument(
        "--mask_column",
        type=str,
        default="mask_image",
        help="The column of the dataset containing the original image's mask.",
    )

    args = parser.parse_args(input_args)
    # Launchers such as torchrun export LOCAL_RANK; it takes precedence over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_jsonl is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args


def combine_rgb_and_mask_to_rgba(rgb_image, mask_image):
    """Merge an image and a mask into one RGBA image whose alpha channel is the mask.

    Args:
        rgb_image: PIL image providing the colour channels. Non-RGB modes are converted to RGB.
        mask_image: PIL image used as the alpha channel. It is converted to ``'L'`` if needed.

    Returns:
        The merged ``RGBA`` PIL image.

    Raises:
        ValueError: If the two images do not have the same size.
    """
    if rgb_image.size != mask_image.size:
        raise ValueError("The RGB image and the mask image must have the same dimensions")

    # The alpha band must be a single 8-bit grayscale channel.
    if mask_image.mode != 'L':
        mask_image = mask_image.convert('L')

    # split() yields one band per channel; normalising to RGB guarantees exactly three bands
    # (RGBA, grayscale or palette inputs would otherwise fail to unpack).
    if rgb_image.mode != "RGB":
        rgb_image = rgb_image.convert("RGB")
    r, g, b = rgb_image.split()

    return Image.merge("RGBA", (r, g, b, mask_image))


def convert_to_np(image, resolution):
    """Load ``image`` as a square RGB array in channel-first layout.

    Args:
        image: A file path, the sentinel string ``"NONE"`` (use a blank image), or an
            already-loaded PIL image.
        resolution: Side length, in pixels, that the image is resized to.

    Returns:
        A numpy array of shape ``(3, resolution, resolution)``. If loading or conversion
        fails, the error is printed and a white image is returned instead of raising.
    """
    def _white_image():
        return PIL.Image.new("RGB", (resolution, resolution), (255, 255, 255))

    try:
        if isinstance(image, str):
            if image == "NONE":
                pil_image = _white_image()
            else:
                # Context manager releases the file handle (PIL may keep it open otherwise,
                # e.g. for multi-frame formats).
                with PIL.Image.open(image) as opened:
                    pil_image = opened.convert("RGB")
        else:
            pil_image = image
        pil_image = pil_image.convert("RGB").resize((resolution, resolution))
    except Exception as e:
        # Best-effort: report the failing input and substitute a blank image.
        print("Load error", image)
        print(e)
        pil_image = _white_image()
    return np.array(pil_image).transpose(2, 0, 1)


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Return the text-encoder class named by a checkpoint's config.

    Args:
        pretrained_model_name_or_path: Local path or hub id of the pretrained checkpoint.
        revision: Checkpoint revision passed to ``PretrainedConfig.from_pretrained``.
        subfolder: Subfolder holding the encoder config.

    Returns:
        ``CLIPTextModelWithProjection`` or ``T5EncoderModel``, chosen from the first entry
        of the config's ``architectures`` list.

    Raises:
        ValueError: If the architecture is neither of the supported classes.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModelWithProjection":
        return CLIPTextModelWithProjection
    elif model_class == "T5EncoderModel":
        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def main():
    """Training entry point; reads its configuration from the module-level ``args``."""
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    from accelerate import DistributedDataParallelKwargs as DDPK

    kwargs = DDPK(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    def download_image(path_or_url, resolution=512):
        # A missing path yields a plain white RGB image.
        if path_or_url is None:
            return PIL.Image.new("RGB", (resolution, resolution), (255, 255, 255))
        # Check if path_or_url is a local file path
        if os.path.exists(path_or_url):
            image = Image.open(path_or_url).convert("RGB").resize((resolution, resolution))
        else:
            image = Image.open(requests.get(path_or_url, stream=True).raw).convert("RGB")
        image = PIL.ImageOps.exif_transpose(image)
        return image

    # --seed defaults to None and manual_seed(None) raises, so only seed the generator when a seed is given.
    generator = torch.Generator(device=accelerator.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load scheduler, tokenizer and models. # Load the tokenizers tokenizer_one = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, ) tokenizer_two = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, ) tokenizer_three = T5TokenizerFast.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_3", revision=args.revision, ) # import correct text encoder classes text_encoder_cls_one = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision ) text_encoder_cls_two = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) text_encoder_cls_three = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" ) # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = 
copy.deepcopy(noise_scheduler) text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, ) transformer = SD3Transformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) # TODO logger.info("Initializing the new channel of DIT from the pretrained DIT.") in_channels = int(1.5 * transformer.config.in_channels) if args.do_mask else 2 * transformer.config.in_channels # 48 for mask out_channels = transformer.pos_embed.proj.out_channels load_num_channel = transformer.config.in_channels print("Do mask",args.do_mask) print("new in_channels",in_channels) print("load_num_channel",load_num_channel) transformer.register_to_config(in_channels=in_channels) print("transformer.pos_embed.proj.weight.shape", transformer.pos_embed.proj.weight.shape) print("load_num_channel", load_num_channel) with torch.no_grad(): new_proj = nn.Conv2d( in_channels, out_channels, kernel_size=(transformer.config.patch_size, transformer.config.patch_size), stride=transformer.config.patch_size, bias=True ) print("new_proj", new_proj) new_proj.weight.zero_() # init.kaiming_normal_(new_proj.weight, mode='fan_out', nonlinearity='relu') # if new_proj.bias is not None and transformer.pos_embed.proj.bias is not None: # new_proj.bias.copy_(transformer.pos_embed.proj.bias) # else: # if new_proj.bias is not None: # new_proj.bias.zero_() new_proj = new_proj.to(transformer.pos_embed.proj.weight.dtype) new_proj.weight[:, :load_num_channel, :, :].copy_(transformer.pos_embed.proj.weight) new_proj.bias.copy_(transformer.pos_embed.proj.bias) print("new_proj", new_proj.weight.shape) print("transformer.pos_embed.proj", transformer.pos_embed.proj.weight.shape) transformer.pos_embed.proj = new_proj for param in 
transformer.parameters(): param.requires_grad = True transformer.requires_grad_(True) vae.requires_grad_(False) if args.train_text_encoder: text_encoder_one.requires_grad_(True) text_encoder_two.requires_grad_(True) text_encoder_three.requires_grad_(True) else: text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) text_encoder_three.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: # due to pytorch#99272, MPS does not yet support bfloat16. raise ValueError( "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
) vae.to(accelerator.device, dtype=torch.float32) if not args.train_text_encoder: text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) text_encoder_three.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() text_encoder_two.gradient_checkpointing_enable() text_encoder_three.gradient_checkpointing_enable() def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for i, model in enumerate(models): if isinstance(unwrap_model(model), SD3Transformer2DModel): unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): hidden_size = unwrap_model(model).config.hidden_size if hidden_size == 768: unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) elif hidden_size == 1280: unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) else: unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3")) else: raise ValueError(f"Wrong model supplied: {type(model)=}.") # make sure to pop weight so that corresponding model is not saved again weights.pop() def load_model_hook(models, input_dir): for _ in range(len(models)): # pop models so that they are not loaded again model = models.pop() # load diffusers style into model if isinstance(unwrap_model(model), SD3Transformer2DModel): load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") 
model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): try: load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: try: load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: raise ValueError(f"Couldn't load the model of type: ({type(model)}).") else: raise ValueError(f"Unsupported model found: {type(model)=}") del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { "params": text_encoder_one.parameters(), "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } text_parameters_two_with_lr = { "params": text_encoder_two.parameters(), "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } text_parameters_three_with_lr = { "params": 
text_encoder_three.parameters(), "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, text_parameters_two_with_lr, text_parameters_three_with_lr, ] else: params_to_optimize = [transformer_parameters_with_lr] if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): logger.warning( f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." "Defaulting to adamW" ) args.optimizer = "adamw" # Initialize the optimizer if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warning( f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}" ) if args.optimizer.lower() == "adamw": if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW optimizer = optimizer_class( params_to_optimize, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) if args.optimizer.lower() == "prodigy": try: import prodigyopt except ImportError: raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy if args.learning_rate <= 0.1: logger.warning( "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) if args.train_text_encoder and args.text_encoder_lr: logger.warning( f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " f"When using prodigy only learning_rate is used as the initial learning rate." 
) # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be # --learning_rate params_to_optimize[1]["lr"] = args.learning_rate params_to_optimize[2]["lr"] = args.learning_rate params_to_optimize[3]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, decouple=args.prodigy_decouple, use_bias_correction=args.prodigy_use_bias_correction, safeguard_warmup=args.prodigy_safeguard_warmup, ) text_encoders_dtypes = [text_encoder_one.dtype, text_encoder_two.dtype, text_encoder_three.dtype] if not args.train_text_encoder: tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three] text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three] def compute_text_embeddings(prompt, text_encoders, tokenizers,text_encoders_dtypes): with torch.no_grad(): prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders, tokenizers, prompt, args.max_sequence_length, text_encoders_dtypes ) prompt_embeds = prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. if args.dataset_name is not None: # Downloading and loading a dataset from the hub. 
dataset = load_dataset( args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, ) else: if args.train_data_jsonl is not None: dataset = load_dataset( "json", data_files=args.train_data_jsonl, cache_dir=args.cache_dir, # split="train" ) # See more about loading custom images at # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder # Preprocessing the datasets. # We need to tokenize inputs and targets. column_names = dataset["train"].column_names # 6. Get the column names for input/target. dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.original_image_column is None: original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] else: original_image_column = args.original_image_column if original_image_column not in column_names: raise ValueError( f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}" ) if args.edit_prompt_column is None: edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1] else: edit_prompt_column = args.edit_prompt_column if edit_prompt_column not in column_names: raise ValueError( f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}" ) if args.edited_image_column is None: edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2] else: edited_image_column = args.edited_image_column if edited_image_column not in column_names: raise ValueError( f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" ) # Preprocessing the datasets. # We need to tokenize input captions and transform the images. # def tokenize_captions(captions): # inputs = tokenizer( # captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" # ) # return inputs.input_ids # Preprocessing the datasets. 
# Spatial augmentation shared by every image stream. It is applied once to the
# concatenated original/edited(/mask) tensor built in `preprocess_images`, so all
# streams passed in one call share a single crop position and flip decision.
train_transforms = transforms.Compose(
    [
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
    ]
)


def preprocess_images(examples):
    """Load, normalize and jointly augment the image columns of a batch.

    Images from the original, edited and (when ``args.do_mask``) mask columns are
    converted with ``convert_to_np``, concatenated along the first axis, mapped
    from [0, 255] to [-1, 1] via ``2 * (x / 255) - 1`` and passed through
    ``train_transforms`` as ONE tensor so they receive identical augmentation.

    Returns:
        The transformed tensor; ``preprocess_train`` splits it back with ``chunk``.
    """
    # NOTE(review): the per-image array layout returned by `convert_to_np`
    # (presumably (3, H, W)) is defined outside this view — confirm.
    original_images = np.concatenate(
        [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
    )
    edited_images = np.concatenate(
        [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
    )
    if args.do_mask:
        mask_images = np.concatenate(
            [convert_to_np(image, args.resolution) for image in examples[args.mask_column]]
        )
        # We need to ensure that the original, edited and mask images undergo the
        # same augmentation transforms, so they are transformed as a single tensor.
        images = np.concatenate([original_images, edited_images, mask_images])
        images = torch.tensor(images)
        images = 2 * (images / 255) - 1
        # NOTE(review): an earlier variant also returned a per-example
        # `mask_index` (mask value == "NONE"); that code is disabled and unused.
        return train_transforms(images)
    # We need to ensure that the original and the edited images undergo the same
    # augmentation transforms.
    images = np.concatenate([original_images, edited_images])
    images = torch.tensor(images)
    images = 2 * (images / 255) - 1
    return train_transforms(images)


def preprocess_train(examples):
    """Dataset transform that attaches pixel tensors to ``examples``.

    Splits the jointly augmented tensor from ``preprocess_images`` back into its
    original / edited (/ mask) parts and reshapes each to
    ``(-1, 3, args.resolution, args.resolution)``.

    Adds the keys ``original_pixel_values`` and ``edited_pixel_values``, plus
    ``mask_pixel_values`` when ``args.do_mask``. Captions are left untouched.

    Returns:
        The mutated ``examples`` mapping.
    """
    # Since the original and edited images were concatenated before applying
    # the transformations, we need to separate them and reshape them accordingly.
    preprocessed_images = preprocess_images(examples)
    if not args.do_mask:
        # chunk(2) splits along dim 0, in concatenation order: original, edited.
        original_images, edited_images = preprocessed_images.chunk(2)
    else:
        # chunk(3) splits along dim 0, in concatenation order: original, edited, mask.
        original_images, edited_images, mask_images = preprocessed_images.chunk(3)
        mask_images = mask_images.reshape(-1, 3, args.resolution, args.resolution)
        examples["mask_pixel_values"] = mask_images
    original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
    edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)

    examples["original_pixel_values"] = original_images
    examples["edited_pixel_values"] = edited_images
    # Captions are passed through as-is; `collate_fn` reads them from the
    # `edit_prompt_column` key.
    return examples


with accelerator.main_process_first():
    # `top_training_data_sample`: keep the FIRST N rows in stored order, then shuffle them.
    if args.top_training_data_sample is not None:
        dataset["train"] = dataset["train"].select(range(args.top_training_data_sample)).shuffle(seed=args.seed)
    # `max_train_samples`: shuffle first, then keep N rows (a random subset).
    if args.max_train_samples is not None:
        dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
    # Set the training transforms (applied lazily when rows are accessed).
    train_dataset = dataset["train"].with_transform(preprocess_train)


def collate_fn(examples):
    """Stack per-example tensors into the batch dict consumed by the training loop.

    Returns contiguous float32 ``original_pixel_values`` and ``edited_pixel_values``,
    plus ``mask_pixel_values`` when ``args.do_mask``. The list of prompts is
    returned under the ``edit_prompt_column`` key.
    """
    original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
    original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
    edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
    edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
    prompts = [example[edit_prompt_column] for example in examples]
    if args.do_mask:
        mask_pixel_values = torch.stack([example["mask_pixel_values"] for example in examples])
        mask_pixel_values = mask_pixel_values.to(memory_format=torch.contiguous_format).float()
        return {
            "original_pixel_values": original_pixel_values,
            "edited_pixel_values": edited_pixel_values,
            edit_prompt_column: prompts,
            "mask_pixel_values": mask_pixel_values,
        }
    else:
        return {
            "original_pixel_values": original_pixel_values,
            "edited_pixel_values": edited_pixel_values,
            edit_prompt_column: prompts,
        }


# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=args.train_batch_size,
    num_workers=args.dataloader_num_workers,
)

# Scheduler and math around the number of training steps.
# `overrode_max_train_steps` records whether max_train_steps was derived here, so
# that it can be recomputed once `accelerator.prepare` has (possibly) changed
# the dataloader length.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

# Warmup/total steps are multiplied by num_processes — presumably because the
# accelerate-prepared scheduler advances once per process per optimizer step
# (verify against the installed accelerate version).
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)

# Prepare everything with our `accelerator`. The text encoders are only wrapped
# when they are being trained.
if args.train_text_encoder:
    (
        transformer,
        text_encoder_one,
        text_encoder_two,
        text_encoder_three,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        transformer,
        text_encoder_one,
        text_encoder_two,
        text_encoder_three,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )
else:
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

# NOTE(review): no dtype casting of the text encoders / VAE happens at this point,
# despite what an earlier comment here claimed.
# The total number of training steps is recalculated further below, because the
# size of the training dataloader may have changed after `accelerator.prepare`.
if accelerator.is_main_process: pretrained_path = args.pretrained_model_name_or_path pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( pretrained_path, vae=vae, text_encoder=accelerator.unwrap_model(text_encoder_one), text_encoder_2=accelerator.unwrap_model(text_encoder_two), text_encoder_3=accelerator.unwrap_model(text_encoder_three), transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) generator = torch.Generator(device=accelerator.device).manual_seed( args.seed) if args.seed else None if args.do_mask: original_image = download_image(args.val_image_url, args.eval_resolution) mask_image = download_image(args.val_mask_url, args.eval_resolution) else: original_image = download_image(args.val_image_url, args.eval_resolution) mask_image = None edited_images = [] with torch.autocast( str(accelerator.device).replace(":0", ""), enabled=(accelerator.mixed_precision == "fp16") | ( accelerator.mixed_precision == "bf16") ): for i in range(args.num_validation_images): edited_images.append( pipeline( args.validation_prompt, image=original_image, mask_img=mask_image, num_inference_steps=50, image_guidance_scale=1.5, guidance_scale=7.5, generator=generator, ).images[0] ) path = join(args.output_dir, f"start_test") os.makedirs(path, exist_ok=True) original_image.save(join(path, f"original.jpg")) for idx, edited_image in enumerate(edited_images): edited_image.save(join(path, f"sample_{idx}.jpg")) del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs print('=========num_update_steps_per_epoch==========', 
num_update_steps_per_epoch) args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: accelerator.init_trackers("instruct-pix2pix_sd3", config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch resume_global_step = global_step * args.gradient_accumulation_steps resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) else: initial_global_step = 0 # Only show the progress bar once on each machine. progress_bar = tqdm(range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", disable=not accelerator.is_local_main_process) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma # with torch.autograd.set_detect_anomaly(True): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() if args.train_text_encoder: text_encoder_one.train() text_encoder_two.train() text_encoder_three.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step # if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: # if step % args.gradient_accumulation_steps == 0: # progress_bar.update(1) # continue models_to_accumulate = [transformer] if args.train_text_encoder: models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three]) with accelerator.accumulate(models_to_accumulate): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit 
# instruction.
# So, first, convert images (the edited / target images) to latent space.
pixel_values = batch["edited_pixel_values"].to(dtype=vae.dtype)
prompt = batch[edit_prompt_column]

# Frozen text encoders: embed the prompts now (`compute_text_embeddings` runs
# under torch.no_grad). Trainable text encoders: only tokenize here; the
# embeddings are computed later in this step so gradients reach the encoders.
if not args.train_text_encoder:
    prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
        prompt, text_encoders, tokenizers,text_encoders_dtypes
    )
else:
    tokens_one = tokenize_prompt(tokenizer_one, prompt)
    tokens_two = tokenize_prompt(tokenizer_two, prompt)
    # Only the T5 tokenizer uses the configurable length; the two CLIP
    # tokenizers fall back to tokenize_prompt's default of 77.
    tokens_three = tokenize_prompt(tokenizer_three, prompt, args.max_sequence_length)

# Sample from the VAE posterior of the target image and scale it.
# NOTE(review): only `scaling_factor` is applied here; SD3 VAEs also define a
# `shift_factor`. Confirm this matches how the pipeline decodes latents.
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents = latents.to(dtype=weight_dtype)

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]

# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
if args.weighting_scheme == "logit_normal":
    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
    u = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
    u = torch.rand(size=(bsz,), device="cpu")
    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
    u = torch.rand(size=(bsz,), device="cpu")
# Map u (expected to lie in [0, 1)) to an index into the scheduler's timesteps.
indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process): x_t = sigma * noise + (1 - sigma) * x_0.
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

# Get the additional image embedding for conditioning.
# Instead of getting a diagonal Gaussian here, we simply take the mode.
# Conditioning latents: encode the unedited original image (and, optionally,
# the mask) deterministically via the posterior mode. Concatenate them
# channel-wise after the noisy target latents:
# [noisy target | original image | (mask)].
original_image_embeds = vae.encode(batch["original_pixel_values"].to(vae.dtype)).latent_dist.mode()
concatenated_noisy_latents = torch.cat([noisy_model_input, original_image_embeds], dim=1)
if args.do_mask:
    mask_embeds = vae.encode(batch["mask_pixel_values"].to(vae.dtype)).latent_dist.mode()
    concatenated_noisy_latents = torch.cat([concatenated_noisy_latents, mask_embeds], dim=1)

# When the text encoders are trained, the prompt embeddings are computed here,
# with gradients, from the token ids prepared earlier in this step. Otherwise
# the embeddings computed without gradients earlier in the step are reused.
if args.train_text_encoder:
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
        tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
        prompt=prompt,
        text_input_ids_list=[tokens_one, tokens_two, tokens_three],
        max_sequence_length=args.max_sequence_length,
        text_encoders_dtypes=text_encoders_dtypes,
    )

# Single forward pass shared by both text-encoder modes. `mask_index` is
# deliberately NOT passed: it is never assigned in this script (its computation
# is commented out in preprocessing). Passing it raised NameError whenever
# --train_text_encoder was set.
model_pred = transformer(
    hidden_states=concatenated_noisy_latents,
    timestep=timesteps,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    return_dict=False,
)[0]

# Precondition the output: since x_t = sigma * noise + (1 - sigma) * x_0,
# x_t - sigma * model_pred is an estimate of the clean latents x_0.
model_pred = model_pred * (-sigmas) + noisy_model_input

# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if args.weighting_scheme == "sigma_sqrt":
    weighting = (sigmas ** -2.0).float()
elif args.weighting_scheme == "cosmap":
    bot = 1 - 2 * sigmas + 2 * sigmas ** 2
    weighting = 2 / (math.pi * bot)
else:
    weighting = torch.ones_like(sigmas)

# With the preconditioning above, the regression target is the clean latents.
target = latents
# NOTE(review): InstructPix2Pix (section 3.2.1 of https://arxiv.org/abs/2211.09800)
# applies conditioning dropout for classifier-free guidance, and validation uses
# guidance_scale / image_guidance_scale. No such dropout is applied in this
# loop — confirm whether it is intended.
# Get the target for loss depending on the prediction type loss = torch.mean( (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1, ) loss = loss.mean() avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() train_loss += avg_loss.item() / args.gradient_accumulation_steps # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( itertools.chain( transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters(), text_encoder_three.parameters(), ) if args.train_text_encoder else transformer.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) train_loss = 0.0 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) 
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") if accelerator.is_main_process: if ( (args.val_image_url is not None) and (args.validation_prompt is not None) and (global_step % args.validation_step == 0) ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) # create pipeline # if not args.train_text_encoder: # text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( # text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three # ) if args.do_mask: pretrained_path = args.ori_model_name_or_path else: pretrained_path = args.pretrained_model_name_or_path pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( pretrained_path, vae=vae, text_encoder=accelerator.unwrap_model(text_encoder_one), text_encoder_2=accelerator.unwrap_model(text_encoder_two), text_encoder_3=accelerator.unwrap_model(text_encoder_three), transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) generator = torch.Generator(device=accelerator.device).manual_seed( args.seed) if args.seed else None # run inference if args.do_mask: original_image = download_image(args.val_image_url,args.eval_resolution) mask_image = download_image(args.val_mask_url,args.eval_resolution) else: original_image = download_image(args.val_image_url,args.eval_resolution) mask_image = None edited_images = [] with torch.autocast( str(accelerator.device).replace(":0", ""), enabled=(accelerator.mixed_precision == "fp16") | ( accelerator.mixed_precision == "bf16") ): for i in range(args.num_validation_images): edited_images.append( pipeline( args.validation_prompt, image=original_image, mask_img=mask_image, num_inference_steps=50, image_guidance_scale=1.5, 
guidance_scale=7.5, generator=generator, ).images[0] ) for tracker in accelerator.trackers: path = join(args.output_dir, f"eval_{global_step}") os.makedirs(path, exist_ok=True) original_image.save(join(path, f"original.jpg")) if tracker.name == "wandb": wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) for idx, edited_image in enumerate(edited_images): wandb_table.add_data( wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt ) # save in the dir as well tracker.log({"validation": wandb_table}) for idx, edited_image in enumerate(edited_images): edited_image.save(join(path, f"sample_{idx}.jpg")) del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) text_encoder_two = unwrap_model(text_encoder_two) text_encoder_three = unwrap_model(text_encoder_three) pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=transformer, text_encoder=text_encoder_one, text_encoder_2=text_encoder_two, text_encoder_3=text_encoder_three, ) else: pipeline = StableDiffusion3InstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=transformer ) pipeline.save_pretrained(args.output_dir) accelerator.end_training() if __name__ == "__main__": args = parse_args() main()