lokesh6309 committed on
Commit
392a2bc
1 Parent(s): 13d1091

Upload 6 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ FFHQ70000_llava_v1.6_13b_4bit_prompt.jsonl filter=lfs diff=lfs merge=lfs -text
Celeba30000_llava_v1.6_13b_4bit_prompt.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
FFHQ70000_llava_v1.6_13b_4bit_prompt.jsonl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3420254d4a07b1ef1b2e20eea38589d6ff3e612e01ffe65ed68c92eaf27473cf
+ size 16462970
README.md CHANGED
@@ -1,3 +1,77 @@
- ---
- license: mit
- ---
+ ---
+ tags:
+ - text-to-image
+ - stable-diffusion
+ - lora
+ - diffusers
+ - template:sd-lora
+ widget:
+ - text: A young woman with smile, wearing a purple hat.
+   parameters:
+     negative_prompt: >-
+       worst quality, low quality, bad anatomy, watermark, text, blurry, cartoon,
+       unreal
+   output:
+     url: images/output.png
+ base_model: runwayml/stable-diffusion-v1-5
+ instance_prompt: null
+ license: mit
+ ---
+ # pytorch_lora_weights.safetensors
+
+ <Gallery />
+
+ ## Model description
+
+ This model is a fine-tuned version of Stable Diffusion using the Low-Rank Adaptation (LoRA) technique. It was trained on the CelebA-HQ and FFHQ datasets, both known for their high-quality images of human faces.
+
+ ### Training Details:
+
+ - **Base Model**: Stable Diffusion v1.5 (`runwayml/stable-diffusion-v1-5`)
+ - **Adaptation Technique**: Low-Rank Adaptation (LoRA)
+ - **Datasets**: CelebA-HQ (30,000 images), FFHQ (70,000 images); captions are supplied as JSONL prompt files (see the sketch below)
+ - **Resolution**: 512×512 fine-tuning for detailed facial synthesis
+
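+ The two `*_prompt.jsonl` files in this upload appear to hold those captions, and the bundled `train_text_to_image_lora.py` consumes such files via `--dataset_json`, resolving each record's image path against the root directory mapped to its dataset name. Below is a minimal sketch of the assumed record layout, using the script's default `name`/`image`/`text` columns (the image sub-paths and captions are purely illustrative):
+
+ ```py
+ import json
+
+ # Illustrative records only: "image" is resolved against the root directory that
+ # DATASET_NAME_MAPPING assigns to each "name" inside train_text_to_image_lora.py.
+ records = [
+     {"name": "celeba-hq", "image": "CelebA-HQ-img/0.jpg",
+      "text": "A young woman with smile, wearing a purple hat."},
+     {"name": "ffhq_1024", "image": "images1024x1024/00000.png",
+      "text": "A middle-aged man with a beard."},
+ ]
+
+ with open("example_prompts.jsonl", "w") as f:
+     for record in records:
+         f.write(json.dumps(record) + "\n")
+ ```
+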
+ ### Example Usage:
+
+ ```py
+ import torch
+ from diffusers import StableDiffusionPipeline
+
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
+ pipeline.load_lora_weights("phil329/face_lora_sd15", weight_name="pytorch_lora_weights.safetensors")
+
+ NEGATIVE_PROMPT = "worst quality, low quality, bad anatomy, watermark, text, blurry, cartoon, unreal"
+ text = 'A young woman with smile, wearing a purple hat.'
+
+ lora_image = pipeline(text, negative_prompt=NEGATIVE_PROMPT).images[0]
+
+ display(lora_image)  # in a notebook; in a plain script use lora_image.save("output.png")
+ ```
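+
+ To compare the base model with the LoRA on the same seed, toggle the adapter around a fixed `torch.Generator`. This is a minimal sketch (the same `enable_lora`/`disable_lora` calls are used by the bundled `gradio_inference_t2i_lora.py`; the seed and output file names are arbitrary):
+
+ ```py
+ import torch
+ from diffusers import StableDiffusionPipeline
+
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
+ pipeline.load_lora_weights("phil329/face_lora_sd15", weight_name="pytorch_lora_weights.safetensors")
+
+ prompt = 'A young woman with smile, wearing a purple hat.'
+ NEGATIVE_PROMPT = "worst quality, low quality, bad anatomy, watermark, text, blurry, cartoon, unreal"
+
+ def generate(seed):
+     # A fixed seed makes the two runs directly comparable.
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+     return pipeline(prompt, negative_prompt=NEGATIVE_PROMPT,
+                     num_inference_steps=30, generator=generator).images[0]
+
+ pipeline.disable_lora()   # plain Stable Diffusion v1.5
+ base_image = generate(42)
+
+ pipeline.enable_lora()    # with the face LoRA active
+ lora_image = generate(42)
+
+ base_image.save("base.png")
+ lora_image.save("lora.png")
+ ```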
+
+ ### Results
+
+ We use the following four prompts:
+ - 'A young woman with smile, wearing a purple hat.'
+ - 'A middle-aged man,beard ,attractive'
+ - 'A girl with long blonde hair'
+ - 'An young man with curry hair'
+
+ The **negative prompt** is the same as in the example code above. All results are randomly generated and **not** cherry-picked.
+
+ If the results are unsatisfactory, try adding a negative prompt, or try different prompts and seeds.
+
+ ![Result](./images/merge.png)
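+
+ The grid above can be regenerated with a loop over the four prompts, using four images per prompt and 20 inference steps as in the bundled Gradio demo's defaults. This is only a sketch: the output file names are illustrative and, with no fixed seed, the samples will differ from the image shown.
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
+ pipeline.load_lora_weights("phil329/face_lora_sd15", weight_name="pytorch_lora_weights.safetensors")
+
+ NEGATIVE_PROMPT = "worst quality, low quality, bad anatomy, watermark, text, blurry, cartoon, unreal"
+ prompts = [
+     'A young woman with smile, wearing a purple hat.',
+     'A middle-aged man,beard ,attractive',
+     'A girl with long blonde hair',
+     'An young man with curry hair',
+ ]
+
+ for i, prompt in enumerate(prompts):
+     images = pipeline(prompt,
+                       negative_prompt=NEGATIVE_PROMPT,
+                       num_images_per_prompt=4,
+                       num_inference_steps=20).images
+     for j, image in enumerate(images):
+         image.save(f"result_{i}_{j}.png")
+ ```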
+
+ ## Download model
+
+ Weights for this model are available in Safetensors format.
+
+ [Download](/phil329/face_lora_sd15/tree/main) them in the Files & versions tab.
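+
+ Alternatively, the weights can be fetched programmatically; a minimal sketch with `huggingface_hub` (the returned path points into the local cache):
+
+ ```py
+ from huggingface_hub import hf_hub_download
+
+ # Downloads pytorch_lora_weights.safetensors into the local HF cache and returns its path.
+ lora_path = hf_hub_download(
+     repo_id="phil329/face_lora_sd15",
+     filename="pytorch_lora_weights.safetensors",
+ )
+ print(lora_path)
+ ```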
gradio_inference_t2i_lora.py ADDED
@@ -0,0 +1,59 @@
+
+ import os
+ import torch
+ import gradio as gr
+ import numpy as np
+
+ from PIL import Image
+
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+
+ NEGATIVE_PROMPT = "worst quality, low quality, bad anatomy, watermark, text, blurry, cartoon, unreal"
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder='unet').to("cuda")
+
+ # unet.load_lora_weights("./exp_output/celeba_finetune/checkpoint-20000", weight_name="pytorch_lora_weights.safetensors")
+
+ # Move the whole pipeline (text encoder, VAE) to the same device as the UNet.
+ pipeline = StableDiffusionPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5",
+     unet=unet).to("cuda")
+
+ pipeline.load_lora_weights("./exp_output/celeba_finetune/checkpoint-20000", weight_name="pytorch_lora_weights.safetensors")
+
+ # Generate a batch of images for a prompt, optionally with the LoRA adapter enabled.
+ def generate_image(text, num_batch, is_use_lora, num_inference_steps):
+     if is_use_lora:
+         pipeline.enable_lora()
+     else:
+         pipeline.disable_lora()
+
+     print('begin inference with text:', text, 'is_use_lora:', is_use_lora)
+     # gr.Number returns floats, so cast to int before passing them to the pipeline.
+     image = pipeline(text,
+                      num_inference_steps=int(num_inference_steps),
+                      num_images_per_prompt=int(num_batch),
+                      negative_prompt=NEGATIVE_PROMPT).images
+     return image
+
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 is_use_lora = gr.Checkbox(label="Use LoRA", value=False)
+                 num_batch = gr.Number(value=4, label="Number of batch")
+                 num_inference_steps = gr.Number(value=20, label="Number of inference steps")
+
+             text_input = gr.Textbox(lines=2, label="Input text", value="A young woman with long hair and a big smile.")
+             generate_button = gr.Button(value="Generate image")
+
+     # image_out = gr.Image(label="Output image", height=512, width=512)
+     image_out = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", object_fit="contain", height="512")
+
+     generate_button.click(generate_image, inputs=[text_input, num_batch, is_use_lora, num_inference_steps], outputs=image_out)
+
+ demo.launch(server_port=7861)
+
pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2009c19f6ad83bba447134f0933f9832ed0f657080d4ac11006ee3b2f4f98d5
+ size 3226184
train_text_to_image_lora.py ADDED
@@ -0,0 +1,888 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
17
+
18
+ from typing import Any, Dict, Iterable, List, Optional, Union
19
+
20
+ import argparse
21
+ import logging
22
+ import math
23
+ import os
24
+ import random
25
+ import shutil
26
+ from pathlib import Path
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ import transformers
34
+ from PIL import Image
35
+ from accelerate import Accelerator
36
+ from accelerate.logging import get_logger
37
+ from accelerate.utils import ProjectConfiguration, set_seed
38
+ from datasets import load_dataset,interleave_datasets
39
+ from huggingface_hub import create_repo, upload_folder
40
+ from packaging import version
41
+ from peft import LoraConfig
42
+ from peft.utils import get_peft_model_state_dict
43
+ from torchvision import transforms
44
+ from tqdm.auto import tqdm
45
+ from transformers import CLIPTextModel, CLIPTokenizer
46
+
47
+ import diffusers
48
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
49
+ from diffusers.optimization import get_scheduler
50
+ from diffusers.training_utils import compute_snr
51
+ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
52
+ # from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
53
+ from diffusers.utils.import_utils import is_xformers_available
54
+ from diffusers.utils.torch_utils import is_compiled_module
55
+
56
+
57
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
58
+ check_min_version("0.25.0")
59
+
60
+ logger = get_logger(__name__, log_level="INFO")
61
+
62
+ def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
63
+ if not isinstance(model, list):
64
+ model = [model]
65
+ for m in model:
66
+ for param in m.parameters():
67
+ # only upcast trainable parameters into fp32
68
+ if param.requires_grad:
69
+ param.data = param.to(dtype)
70
+
71
+
72
+ def parse_args():
73
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
74
+ parser.add_argument(
75
+ "--pretrained_model_name_or_path",
76
+ type=str,
77
+ default=None,
78
+ required=True,
79
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
80
+ )
81
+ parser.add_argument(
82
+ "--revision",
83
+ type=str,
84
+ default=None,
85
+ required=False,
86
+ help="Revision of pretrained model identifier from huggingface.co/models.",
87
+ )
88
+ parser.add_argument(
89
+ "--variant",
90
+ type=str,
91
+ default=None,
92
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
93
+ )
94
+ parser.add_argument(
95
+ "--dataset_json",
96
+ type=str,
97
+ default=None,
98
+ nargs="+",
99
+ help=(
100
+ "A json file containing the dataset. The file must contain a list of dictionaries, where each dictionary"
101
+ ),
102
+ )
103
+ parser.add_argument(
104
+ "--dataset_config_name",
105
+ type=str,
106
+ default=None,
107
+ help="The config of the Dataset, leave as None if there's only one config.",
108
+ )
109
+ parser.add_argument('--name_column', type=str, default='name', help='The column of the dataset containing the name of the dataset.')
110
+ parser.add_argument(
111
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
112
+ )
113
+ parser.add_argument(
114
+ "--caption_column",
115
+ type=str,
116
+ default="text",
117
+ help="The column of the dataset containing a caption or a list of captions.",
118
+ )
119
+ parser.add_argument(
120
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
121
+ )
122
+ parser.add_argument(
123
+ "--num_validation_images",
124
+ type=int,
125
+ default=4,
126
+ help="Number of images that should be generated during validation with `validation_prompt`.",
127
+ )
128
+ parser.add_argument(
129
+ "--validation_epochs",
130
+ type=int,
131
+ default=1,
132
+ help=(
133
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
134
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
135
+ ),
136
+ )
137
+ parser.add_argument(
138
+ "--max_train_samples",
139
+ type=int,
140
+ default=None,
141
+ help=(
142
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
143
+ "value if set."
144
+ ),
145
+ )
146
+ parser.add_argument(
147
+ "--output_dir",
148
+ type=str,
149
+ default="sd-model-finetuned-lora",
150
+ help="The output directory where the model predictions and checkpoints will be written.",
151
+ )
152
+ parser.add_argument(
153
+ "--cache_dir",
154
+ type=str,
155
+ default=None,
156
+ help="The directory where the downloaded models and datasets will be stored.",
157
+ )
158
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
159
+ parser.add_argument(
160
+ "--resolution",
161
+ type=int,
162
+ default=512,
163
+ help=(
164
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
165
+ " resolution"
166
+ ),
167
+ )
168
+ parser.add_argument(
169
+ "--center_crop",
170
+ default=False,
171
+ action="store_true",
172
+ help=(
173
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
174
+ " cropped. The images will be resized to the resolution first before cropping."
175
+ ),
176
+ )
177
+ parser.add_argument(
178
+ "--random_flip",
179
+ action="store_true",
180
+ help="whether to randomly flip images horizontally",
181
+ )
182
+ parser.add_argument(
183
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
184
+ )
185
+ parser.add_argument("--num_train_epochs", type=int, default=100)
186
+ parser.add_argument(
187
+ "--max_train_steps",
188
+ type=int,
189
+ default=None,
190
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
191
+ )
192
+ parser.add_argument(
193
+ "--gradient_accumulation_steps",
194
+ type=int,
195
+ default=1,
196
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
197
+ )
198
+ parser.add_argument(
199
+ "--gradient_checkpointing",
200
+ action="store_true",
201
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
202
+ )
203
+ parser.add_argument(
204
+ "--learning_rate",
205
+ type=float,
206
+ default=1e-4,
207
+ help="Initial learning rate (after the potential warmup period) to use.",
208
+ )
209
+ parser.add_argument(
210
+ "--scale_lr",
211
+ action="store_true",
212
+ default=False,
213
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
214
+ )
215
+ parser.add_argument(
216
+ "--lr_scheduler",
217
+ type=str,
218
+ default="constant",
219
+ help=(
220
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
221
+ ' "constant", "constant_with_warmup"]'
222
+ ),
223
+ )
224
+ parser.add_argument(
225
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
226
+ )
227
+ parser.add_argument(
228
+ "--snr_gamma",
229
+ type=float,
230
+ default=None,
231
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
232
+ "More details here: https://arxiv.org/abs/2303.09556.",
233
+ )
234
+ parser.add_argument(
235
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
236
+ )
237
+ parser.add_argument(
238
+ "--allow_tf32",
239
+ action="store_true",
240
+ help=(
241
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
242
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--dataloader_num_workers",
247
+ type=int,
248
+ default=0,
249
+ help=(
250
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
251
+ ),
252
+ )
253
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
254
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
255
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
256
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
257
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
258
+ parser.add_argument(
259
+ "--prediction_type",
260
+ type=str,
261
+ default=None,
262
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
263
+ )
264
+ parser.add_argument(
265
+ "--hub_model_id",
266
+ type=str,
267
+ default=None,
268
+ help="The name of the repository to keep in sync with the local `output_dir`.",
269
+ )
270
+ parser.add_argument(
271
+ "--logging_dir",
272
+ type=str,
273
+ default="logs",
274
+ help=(
275
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
276
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
277
+ ),
278
+ )
279
+ parser.add_argument(
280
+ "--mixed_precision",
281
+ type=str,
282
+ default='no',
283
+ choices=["no", "fp16", "bf16"],
284
+ help=(
285
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
286
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
287
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
288
+ ),
289
+ )
290
+ parser.add_argument(
291
+ "--report_to",
292
+ type=str,
293
+ default="tensorboard",
294
+ help=(
295
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
296
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
297
+ ),
298
+ )
299
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
300
+ parser.add_argument(
301
+ "--checkpointing_steps",
302
+ type=int,
303
+ default=500,
304
+ help=(
305
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
306
+ " training using `--resume_from_checkpoint`."
307
+ ),
308
+ )
309
+ parser.add_argument(
310
+ "--checkpoints_total_limit",
311
+ type=int,
312
+ default=None,
313
+ help=("Max number of checkpoints to store."),
314
+ )
315
+ parser.add_argument(
316
+ "--resume_from_checkpoint",
317
+ type=str,
318
+ default=None,
319
+ help=(
320
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
321
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
322
+ ),
323
+ )
324
+ parser.add_argument(
325
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
326
+ )
327
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
328
+ parser.add_argument(
329
+ "--rank",
330
+ type=int,
331
+ default=4,
332
+ help=("The dimension of the LoRA update matrices."),
333
+ )
334
+
335
+ args = parser.parse_args()
336
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
337
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
338
+ args.local_rank = env_local_rank
339
+
340
+ # Sanity checks
341
+ if args.dataset_json is None:
342
+ raise ValueError("Need either a dataset name or a training folder.")
343
+
344
+ return args
345
+
346
+
347
+ DATASET_NAME_MAPPING = {
348
+ "celeba-hq": '/mnt/pami202/blli/DATASET/CelebAMask-HQ',
349
+ "ffhq_1024": '/mnt/pami202/blli/DATASET/FFHQ',
350
+ }
351
+
352
+
353
+ def main():
354
+ args = parse_args()
355
+
356
+ logging_dir = Path(args.output_dir, args.logging_dir)
357
+
358
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
359
+
360
+ accelerator = Accelerator(
361
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
362
+ mixed_precision=args.mixed_precision,
363
+ log_with=args.report_to,
364
+ project_config=accelerator_project_config,
365
+ )
366
+ if args.report_to == "wandb":
367
+ if not is_wandb_available():
368
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
369
+ import wandb
370
+
371
+ # Make one log on every process with the configuration for debugging.
372
+ logging.basicConfig(
373
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
374
+ datefmt="%m/%d/%Y %H:%M:%S",
375
+ level=logging.INFO,
376
+ )
377
+ logger.info(accelerator.state, main_process_only=False)
378
+ if accelerator.is_local_main_process:
379
+ datasets.utils.logging.set_verbosity_warning()
380
+ transformers.utils.logging.set_verbosity_warning()
381
+ diffusers.utils.logging.set_verbosity_info()
382
+ else:
383
+ datasets.utils.logging.set_verbosity_error()
384
+ transformers.utils.logging.set_verbosity_error()
385
+ diffusers.utils.logging.set_verbosity_error()
386
+
387
+ # If passed along, set the training seed now.
388
+ if args.seed is not None:
389
+ set_seed(args.seed)
390
+
391
+ # Handle the repository creation
392
+ if accelerator.is_main_process:
393
+ if args.output_dir is not None:
394
+ os.makedirs(args.output_dir, exist_ok=True)
395
+
396
+ # Load scheduler, tokenizer and models.
397
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
398
+ tokenizer = CLIPTokenizer.from_pretrained(
399
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
400
+ )
401
+ text_encoder = CLIPTextModel.from_pretrained(
402
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
403
+ )
404
+ vae = AutoencoderKL.from_pretrained(
405
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
406
+ )
407
+ unet = UNet2DConditionModel.from_pretrained(
408
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
409
+ )
410
+ # freeze parameters of models to save more memory
411
+ unet.requires_grad_(False)
412
+ vae.requires_grad_(False)
413
+ text_encoder.requires_grad_(False)
414
+
415
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
416
+ # as these weights are only used for inference, keeping weights in full precision is not required.
417
+ weight_dtype = torch.float32
418
+ if accelerator.mixed_precision == "fp16":
419
+ weight_dtype = torch.float16
420
+ elif accelerator.mixed_precision == "bf16":
421
+ weight_dtype = torch.bfloat16
422
+
423
+ # Freeze the unet parameters before adding adapters
424
+ for param in unet.parameters():
425
+ param.requires_grad_(False)
426
+
427
+ unet_lora_config = LoraConfig(
428
+ r=args.rank,
429
+ lora_alpha=args.rank,
430
+ init_lora_weights="gaussian",
431
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
432
+ )
433
+
434
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
435
+ unet.to(accelerator.device, dtype=weight_dtype)
436
+ vae.to(accelerator.device, dtype=weight_dtype)
437
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
438
+
439
+ # Add adapter and make sure the trainable params are in float32.
440
+ unet.add_adapter(unet_lora_config)
441
+ if args.mixed_precision == "fp16":
442
+ # only upcast trainable parameters (LoRA) into fp32
443
+ cast_training_params(unet, dtype=torch.float32)
444
+
445
+ if args.enable_xformers_memory_efficient_attention:
446
+ if is_xformers_available():
447
+ import xformers
448
+
449
+ xformers_version = version.parse(xformers.__version__)
450
+ if xformers_version == version.parse("0.0.16"):
451
+ logger.warning(
452
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
453
+ )
454
+ unet.enable_xformers_memory_efficient_attention()
455
+ else:
456
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
457
+
458
+ lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
459
+
460
+ if args.gradient_checkpointing:
461
+ unet.enable_gradient_checkpointing()
462
+
463
+ # Enable TF32 for faster training on Ampere GPUs,
464
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
465
+ if args.allow_tf32:
466
+ torch.backends.cuda.matmul.allow_tf32 = True
467
+
468
+ if args.scale_lr:
469
+ args.learning_rate = (
470
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
471
+ )
472
+
473
+ # Initialize the optimizer
474
+ if args.use_8bit_adam:
475
+ try:
476
+ import bitsandbytes as bnb
477
+ except ImportError:
478
+ raise ImportError(
479
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
480
+ )
481
+
482
+ optimizer_cls = bnb.optim.AdamW8bit
483
+ else:
484
+ optimizer_cls = torch.optim.AdamW
485
+
486
+ optimizer = optimizer_cls(
487
+ lora_layers,
488
+ lr=args.learning_rate,
489
+ betas=(args.adam_beta1, args.adam_beta2),
490
+ weight_decay=args.adam_weight_decay,
491
+ eps=args.adam_epsilon,
492
+ )
493
+
494
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
495
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
496
+
497
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
498
+ # download the dataset.
499
+
500
+ dataset = load_dataset('json', data_files=args.dataset_json)
501
+
502
+ # Preprocessing the datasets.
503
+ # We need to tokenize inputs and targets.
504
+ column_names = dataset["train"].column_names
505
+
506
+ # 6. Get the column names for input/target.
507
+ name_column = args.name_column
508
+ image_column = args.image_column
509
+ caption_column = args.caption_column
510
+
511
+ # Preprocessing the datasets.
512
+ # We need to tokenize input captions and transform the images.
513
+ def tokenize_captions(examples, is_train=True):
514
+ captions = []
515
+ for caption in examples[caption_column]:
516
+ if isinstance(caption, str):
517
+ captions.append(caption)
518
+ elif isinstance(caption, (list, np.ndarray)):
519
+ # take a random caption if there are multiple
520
+ captions.append(random.choice(caption) if is_train else caption[0])
521
+ else:
522
+ raise ValueError(
523
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
524
+ )
525
+ inputs = tokenizer(
526
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
527
+ )
528
+ return inputs.input_ids
529
+
530
+ # Preprocessing the datasets.
531
+ train_transforms = transforms.Compose(
532
+ [
533
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
534
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
535
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
536
+ transforms.ToTensor(),
537
+ transforms.Normalize([0.5], [0.5]),
538
+ ]
539
+ )
540
+
541
+ def unwrap_model(model):
542
+ model = accelerator.unwrap_model(model)
543
+ model = model._orig_mod if is_compiled_module(model) else model
544
+ return model
545
+
546
+ def preprocess_train(examples):
547
+ images = []
548
+ for name,image in zip(examples[name_column],examples[image_column]):
549
+ path = DATASET_NAME_MAPPING[name]
550
+ images.append(Image.open(os.path.join(path, image)).convert("RGB"))
551
+
552
+ # images = [image.convert("RGB") for image in examples[image_column]]
553
+ examples["pixel_values"] = [train_transforms(image) for image in images]
554
+ examples["input_ids"] = tokenize_captions(examples)
555
+ return examples
556
+
557
+ with accelerator.main_process_first():
558
+ if args.max_train_samples is not None:
559
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
560
+ # Set the training transforms
561
+ train_dataset = dataset["train"].with_transform(preprocess_train)
562
+
563
+ def collate_fn(examples):
564
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
565
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
566
+ input_ids = torch.stack([example["input_ids"] for example in examples])
567
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
568
+
569
+ # DataLoaders creation:
570
+ train_dataloader = torch.utils.data.DataLoader(
571
+ train_dataset,
572
+ shuffle=True,
573
+ collate_fn=collate_fn,
574
+ batch_size=args.train_batch_size,
575
+ num_workers=args.dataloader_num_workers,
576
+ )
577
+
578
+ # Scheduler and math around the number of training steps.
579
+ overrode_max_train_steps = False
580
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
581
+ if args.max_train_steps is None:
582
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
583
+ overrode_max_train_steps = True
584
+
585
+ lr_scheduler = get_scheduler(
586
+ args.lr_scheduler,
587
+ optimizer=optimizer,
588
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
589
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
590
+ )
591
+
592
+ # Prepare everything with our `accelerator`.
593
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
594
+ unet, optimizer, train_dataloader, lr_scheduler
595
+ )
596
+
597
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
598
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
599
+ if overrode_max_train_steps:
600
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
601
+ # Afterwards we recalculate our number of training epochs
602
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
603
+
604
+ # We need to initialize the trackers we use, and also store our configuration.
605
+ # The trackers initializes automatically on the main process.
606
+ if accelerator.is_main_process:
607
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
608
+
609
+ # Train!
610
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
611
+
612
+ logger.info("***** Running training *****")
613
+ logger.info(f" Num examples = {len(train_dataset)}")
614
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
615
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
616
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
617
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
618
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
619
+ global_step = 0
620
+ first_epoch = 0
621
+
622
+ # Potentially load in the weights and states from a previous save
623
+ if args.resume_from_checkpoint:
624
+ if args.resume_from_checkpoint != "latest":
625
+ path = os.path.basename(args.resume_from_checkpoint)
626
+ # path = args.resume_from_checkpoint
627
+ else:
628
+ # Get the most recent checkpoint
629
+ dirs = os.listdir(args.output_dir)
630
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
631
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
632
+ path = dirs[-1] if len(dirs) > 0 else None
633
+
634
+ if path is None:
635
+ accelerator.print(
636
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
637
+ )
638
+ args.resume_from_checkpoint = None
639
+ initial_global_step = 0
640
+ else:
641
+ accelerator.print(f"Resuming from checkpoint {path}")
642
+ accelerator.load_state(os.path.join(args.output_dir, path))
643
+ # accelerator.load_state(path)
644
+ global_step = int(path.split("-")[1])
645
+
646
+ initial_global_step = global_step
647
+ first_epoch = global_step // num_update_steps_per_epoch
648
+ else:
649
+ initial_global_step = 0
650
+
651
+ progress_bar = tqdm(
652
+ range(0, args.max_train_steps),
653
+ initial=initial_global_step,
654
+ desc="Steps",
655
+ # Only show the progress bar once on each machine.
656
+ disable=not accelerator.is_local_main_process,
657
+ )
658
+
659
+ for epoch in range(first_epoch, args.num_train_epochs):
660
+ unet.train()
661
+ train_loss = 0.0
662
+ for step, batch in enumerate(train_dataloader):
663
+ with accelerator.accumulate(unet):
664
+ # Convert images to latent space
665
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
666
+ latents = latents * vae.config.scaling_factor
667
+
668
+ # Sample noise that we'll add to the latents
669
+ noise = torch.randn_like(latents)
670
+ if args.noise_offset:
671
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
672
+ noise += args.noise_offset * torch.randn(
673
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
674
+ )
675
+
676
+ bsz = latents.shape[0]
677
+ # Sample a random timestep for each image
678
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
679
+ timesteps = timesteps.long()
680
+
681
+ # Add noise to the latents according to the noise magnitude at each timestep
682
+ # (this is the forward diffusion process)
683
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
684
+
685
+ # Get the text embedding for conditioning
686
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
687
+
688
+ # Get the target for loss depending on the prediction type
689
+ if args.prediction_type is not None:
690
+ # set prediction_type of scheduler if defined
691
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
692
+
693
+ if noise_scheduler.config.prediction_type == "epsilon":
694
+ target = noise
695
+ elif noise_scheduler.config.prediction_type == "v_prediction":
696
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
697
+ else:
698
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
699
+
700
+ # Predict the noise residual and compute loss
701
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
702
+
703
+ if args.snr_gamma is None:
704
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
705
+ else:
706
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
707
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
708
+ # This is discussed in Section 4.2 of the same paper.
709
+ snr = compute_snr(noise_scheduler, timesteps)
710
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
711
+ dim=1
712
+ )[0]
713
+ if noise_scheduler.config.prediction_type == "epsilon":
714
+ mse_loss_weights = mse_loss_weights / snr
715
+ elif noise_scheduler.config.prediction_type == "v_prediction":
716
+ mse_loss_weights = mse_loss_weights / (snr + 1)
717
+
718
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
719
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
720
+ loss = loss.mean()
721
+
722
+ # Gather the losses across all processes for logging (if we use distributed training).
723
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
724
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
725
+
726
+ # Backpropagate
727
+ accelerator.backward(loss)
728
+ if accelerator.sync_gradients:
729
+ params_to_clip = lora_layers
730
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
731
+ optimizer.step()
732
+ lr_scheduler.step()
733
+ optimizer.zero_grad()
734
+
735
+ # Checks if the accelerator has performed an optimization step behind the scenes
736
+ if accelerator.sync_gradients:
737
+ progress_bar.update(1)
738
+ global_step += 1
739
+ accelerator.log({"train_loss": train_loss}, step=global_step)
740
+ train_loss = 0.0
741
+
742
+ if global_step % args.checkpointing_steps == 0:
743
+ if accelerator.is_main_process:
744
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
745
+ if args.checkpoints_total_limit is not None:
746
+ checkpoints = os.listdir(args.output_dir)
747
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
748
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
749
+
750
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
751
+ if len(checkpoints) >= args.checkpoints_total_limit:
752
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
753
+ removing_checkpoints = checkpoints[0:num_to_remove]
754
+
755
+ logger.info(
756
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
757
+ )
758
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
759
+
760
+ for removing_checkpoint in removing_checkpoints:
761
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
762
+ shutil.rmtree(removing_checkpoint)
763
+
764
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
765
+ accelerator.save_state(save_path)
766
+
767
+ unwrapped_unet = unwrap_model(unet)
768
+ unet_lora_state_dict = convert_state_dict_to_diffusers(
769
+ get_peft_model_state_dict(unwrapped_unet)
770
+ )
771
+
772
+ StableDiffusionPipeline.save_lora_weights(
773
+ save_directory=save_path,
774
+ unet_lora_layers=unet_lora_state_dict,
775
+ safe_serialization=True,
776
+ )
777
+
778
+ logger.info(f"Saved state to {save_path}")
779
+
780
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
781
+ progress_bar.set_postfix(**logs)
782
+
783
+ if global_step >= args.max_train_steps:
784
+ break
785
+
786
+ if accelerator.is_main_process:
787
+ if args.validation_prompt is not None and (epoch % args.validation_epochs == 0 or epoch == 0):
788
+ logger.info(
789
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
790
+ f" {args.validation_prompt}."
791
+ )
792
+ # create pipeline
793
+ pipeline = DiffusionPipeline.from_pretrained(
794
+ args.pretrained_model_name_or_path,
795
+ unet=unwrap_model(unet),
796
+ revision=args.revision,
797
+ variant=args.variant,
798
+ torch_dtype=weight_dtype,
799
+ )
800
+ pipeline = pipeline.to(accelerator.device)
801
+ pipeline.set_progress_bar_config(disable=True)
802
+
803
+ # run inference
804
+ generator = torch.Generator(device=accelerator.device)
805
+ if args.seed is not None:
806
+ generator = generator.manual_seed(args.seed)
807
+ images = []
808
+ with torch.cuda.amp.autocast():
809
+ for _ in range(args.num_validation_images):
810
+ images.append(
811
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
812
+ )
813
+
814
+ for tracker in accelerator.trackers:
815
+ if tracker.name == "tensorboard":
816
+ np_images = np.stack([np.asarray(img) for img in images])
817
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
818
+ if tracker.name == "wandb":
819
+ tracker.log(
820
+ {
821
+ "validation": [
822
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
823
+ for i, image in enumerate(images)
824
+ ]
825
+ }
826
+ )
827
+
828
+ del pipeline
829
+ torch.cuda.empty_cache()
830
+
831
+ # Save the lora layers
832
+ accelerator.wait_for_everyone()
833
+ if accelerator.is_main_process:
834
+ unet = unet.to(torch.float32)
835
+
836
+ unwrapped_unet = unwrap_model(unet)
837
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
838
+ StableDiffusionPipeline.save_lora_weights(
839
+ save_directory=args.output_dir,
840
+ unet_lora_layers=unet_lora_state_dict,
841
+ safe_serialization=True,
842
+ )
843
+
844
+ # Final inference
845
+ # Load previous pipeline
846
+ if args.validation_prompt is not None:
847
+ pipeline = DiffusionPipeline.from_pretrained(
848
+ args.pretrained_model_name_or_path,
849
+ revision=args.revision,
850
+ variant=args.variant,
851
+ torch_dtype=weight_dtype,
852
+ )
853
+ pipeline = pipeline.to(accelerator.device)
854
+
855
+ # load attention processors
856
+ pipeline.load_lora_weights(args.output_dir)
857
+
858
+ # run inference
859
+ generator = torch.Generator(device=accelerator.device)
860
+ if args.seed is not None:
861
+ generator = generator.manual_seed(args.seed)
862
+ images = []
863
+ with torch.cuda.amp.autocast():
864
+ for _ in range(args.num_validation_images):
865
+ images.append(
866
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
867
+ )
868
+
869
+ for tracker in accelerator.trackers:
870
+ if len(images) != 0:
871
+ if tracker.name == "tensorboard":
872
+ np_images = np.stack([np.asarray(img) for img in images])
873
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
874
+ if tracker.name == "wandb":
875
+ tracker.log(
876
+ {
877
+ "test": [
878
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
879
+ for i, image in enumerate(images)
880
+ ]
881
+ }
882
+ )
883
+
884
+ accelerator.end_training()
885
+
886
+
887
+ if __name__ == "__main__":
888
+ main()