mkshing committed 47390c8 (parent: e2a20af): apply v0.2.0

Files changed:
- app_inference.py +1 -1
- app_training.py +1 -2
- inference.py +10 -2
- requirements.txt +1 -1
- train_svdiff.py +108 -64
- trainer.py +2 -0
app_inference.py CHANGED

@@ -12,7 +12,7 @@ from utils import find_exp_dirs
 
 SAMPLE_MODEL_IDS = [
     'svdiff-library/svdiff_dog_example',
-    '
+    'svdiff-library/svdiff_chair_example',
 ]
 
 
app_training.py CHANGED

@@ -64,7 +64,7 @@ def create_training_demo(trainer: Trainer,
                     label='Resolution')
                 num_training_steps = gr.Number(
                     label='Number of Training Steps', value=1000, precision=0)
-                learning_rate = gr.Number(label='Learning Rate', value=0.
+                learning_rate = gr.Number(label='Learning Rate', value=0.001)
                 gradient_accumulation = gr.Number(
                     label='Number of Gradient Accumulation',
                     value=1,
@@ -91,7 +91,6 @@ def create_training_demo(trainer: Trainer,
             gr.Markdown('''
             - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
             - It takes a few minutes to download the base model first.
-            - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
             - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
             - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
             - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
inference.py CHANGED

@@ -8,7 +8,7 @@ import PIL.Image
 import torch
 from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
 from huggingface_hub import ModelCard
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
 
 
 
@@ -58,18 +58,26 @@ class InferencePipeline:
         for module in unet.modules():
             if hasattr(module, "perform_svd"):
                 module.perform_svd()
-
+        if self.device.type != 'cpu':
+            unet = unet.to(self.device, dtype=torch.float16)
+        text_encoder = load_text_encoder_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="text_encoder")
+        if self.device.type != 'cpu':
+            text_encoder = text_encoder.to(self.device, dtype=torch.float16)
+        else:
+            text_encoder = text_encoder.to(self.device)
         if base_model_id != self.base_model_id:
             if self.device.type == 'cpu':
                 pipe = DiffusionPipeline.from_pretrained(
                     base_model_id,
                     unet=unet,
+                    text_encoder=text_encoder,
                     use_auth_token=self.hf_token
                 )
             else:
                 pipe = DiffusionPipeline.from_pretrained(
                     base_model_id,
                     unet=unet,
+                    text_encoder=text_encoder,
                     torch_dtype=torch.float16,
                     use_auth_token=self.hf_token
                 )
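Note: for reference, a minimal sketch of how the updated inference path fits together end to end. The base model id, checkpoint id, and prompt are placeholders, and it assumes `load_unet_for_svdiff` accepts a `spectral_shifts_ckpt` argument the same way `load_text_encoder_for_svdiff` does in the hunk above; this is an illustration, not the Space's own code.

import torch
from diffusers import DiffusionPipeline
from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff

# Placeholders: any diffusers-compatible base model plus an SVDiff spectral-shift checkpoint.
base_model_id = "runwayml/stable-diffusion-v1-5"
ckpt_id = "svdiff-library/svdiff_dog_example"

# Load the base UNet and text encoder with the learned spectral shifts applied on top.
unet = load_unet_for_svdiff(base_model_id, spectral_shifts_ckpt=ckpt_id, subfolder="unet")
for module in unet.modules():
    if hasattr(module, "perform_svd"):
        module.perform_svd()
text_encoder = load_text_encoder_for_svdiff(base_model_id, spectral_shifts_ckpt=ckpt_id, subfolder="text_encoder")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet.to(device, dtype=dtype),
    text_encoder=text_encoder.to(device, dtype=dtype),
    torch_dtype=dtype,
)
pipe.to(device)
image = pipe("a photo of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("out.png")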
requirements.txt CHANGED

@@ -1,4 +1,4 @@
-svdiff-pytorch
+svdiff-pytorch>=0.2.0
 bitsandbytes==0.35.0
 python-slugify==7.0.0
 tomesd
train_svdiff.py CHANGED

@@ -7,6 +7,7 @@ import warnings
 from pathlib import Path
 from typing import Optional
 from packaging import version
+import itertools
 
 import numpy as np
 import torch
@@ -22,7 +23,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
+from transformers import CLIPTextModel, AutoTokenizer, PretrainedConfig
 
 import diffusers
 from diffusers import __version__
@@ -33,7 +34,7 @@ from diffusers import (
     StableDiffusionPipeline,
     DPMSolverMultistepScheduler,
 )
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING
 from diffusers.loaders import AttnProcsLayers
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
@@ -78,26 +79,6 @@ These are SVDiff weights for {base_model}. The weights were trained on {prompt}
     f.write(yaml + model_card)
 
 
-def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
-    text_encoder_config = PretrainedConfig.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=revision,
-    )
-    model_class = text_encoder_config.architectures[0]
-
-    if model_class == "CLIPTextModel":
-        from transformers import CLIPTextModel
-
-        return CLIPTextModel
-    elif model_class == "RobertaSeriesModelWithTransformation":
-        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
-
-        return RobertaSeriesModelWithTransformation
-    else:
-        raise ValueError(f"{model_class} is not supported.")
-
-
 def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -271,9 +252,15 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=
+        default=1e-3,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--learning_rate_1d",
+        type=float,
+        default=1e-6,
+        help="Initial learning rate (after the potential warmup period) to use for 1-d weights",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -380,6 +367,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_token_merging", action="store_true", help="Whether or not to use tomesd on prior generation"
    )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train spectral shifts of the text encoder. If set, the text encoder should be float32 precision.",
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
     else:
@@ -594,6 +586,11 @@ def main(args):
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -700,14 +697,14 @@ def main(args):
         use_fast=False,
     )
 
-    # import correct text encoder class
-    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
-
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-    text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-    )
+    if args.train_text_encoder:
+        text_encoder = load_text_encoder_for_svdiff(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    else:
+        text_encoder = CLIPTextModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        )
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=True)
 
@@ -716,26 +713,26 @@ def main(args):
     text_encoder.requires_grad_(False)
     unet.requires_grad_(False)
     optim_params = []
+    optim_params_1d = []
     for n, p in unet.named_parameters():
         if "delta" in n:
             p.requires_grad = True
-            optim_params.append(p)
+            if "norm" in n:
+                optim_params_1d.append(p)
+            else:
+                optim_params.append(p)
+    if args.train_text_encoder:
+        for n, p in text_encoder.named_parameters():
+            if "delta" in n:
+                p.requires_grad = True
+                if "norm" in n:
+                    optim_params_1d.append(p)
+                else:
+                    optim_params.append(p)
+
     total_params = sum(p.numel() for p in optim_params)
     print(f"Number of Trainable Parameters: {total_params * 1.e-6:.2f} M")
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    # unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -751,12 +748,26 @@ def main(args):
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
 
-
-
-
+    # Check that all trainable models are in full precision
+    low_precision_error_string = (
+        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training. copy of the weights should still be float32."
+    )
+
+    if accelerator.unwrap_model(unet).dtype != torch.float32:
+        raise ValueError(
+            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
         )
 
+    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+        raise ValueError(
+            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+            f" {low_precision_error_string}"
+        )
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -782,7 +793,7 @@ def main(args):
 
     # Optimizer creation
     optimizer = optimizer_class(
-        optim_params,
+        [{"params": optim_params}, {"params": optim_params_1d, "lr": args.learning_rate_1d}],
        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
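Note: the two learning rates added above (`--learning_rate` for the 2-d spectral shifts, `--learning_rate_1d` for the 1-d `norm` deltas) are wired up through ordinary PyTorch parameter groups. A minimal sketch of that pattern on a toy module (the module and values here are illustrative, not the training script's):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Toy stand-in: one 2-d weight (proj) and one set of 1-d weights (norm)."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(self.proj(x))

model = TinyBlock()

# Split parameters the same way the script splits optim_params / optim_params_1d.
params, params_1d = [], []
for name, p in model.named_parameters():
    (params_1d if "norm" in name else params).append(p)

# Two parameter groups: the per-group "lr" overrides the optimizer default,
# mirroring [{"params": optim_params}, {"params": optim_params_1d, "lr": args.learning_rate_1d}].
optimizer = torch.optim.AdamW(
    [{"params": params}, {"params": params_1d, "lr": 1e-6}],
    lr=1e-3,
)

for group in optimizer.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))

Because the optimizer now has two groups, `lr_scheduler.get_last_lr()` returns one value per group, which is why the logging line later in this diff adds `lr_1d` as `get_last_lr()[1]`.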
@@ -826,9 +837,29 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move unet, vae and text_encoder to device and cast to weight_dtype
+    # unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.train_text_encoder:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
+
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -842,14 +873,27 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("svdiff-pytorch", config=vars(args))
 
-
+    # cache keys to save
+    state_dict_keys = [k for k in accelerator.unwrap_model(unet).state_dict().keys() if "delta" in k]
+    if args.train_text_encoder:
+        state_dict_keys_te = [k for k in accelerator.unwrap_model(text_encoder).state_dict().keys() if "delta" in k]
+
+    def save_weights(step, save_path=None):
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
-            save_path
+            if save_path is None:
+                save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
             os.makedirs(save_path, exist_ok=True)
-            unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-            state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+            state_dict = accelerator.unwrap_model(unet, keep_fp32_wrapper=True).state_dict()
+            # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+            state_dict = {k: state_dict[k] for k in state_dict_keys}
             save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
+            if args.train_text_encoder:
+                state_dict = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).state_dict()
+                # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+                state_dict = {k: state_dict[k] for k in state_dict_keys_te}
+                save_file(state_dict, os.path.join(save_path, "spectral_shifts_te.safetensors"))
+
             print(f"[*] Weights saved at {save_path}")
 
     # Train!
@@ -897,6 +941,8 @@ def main(args):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -952,7 +998,11 @@ def main(args):
 
             accelerator.backward(loss)
             if accelerator.sync_gradients:
-                params_to_clip =
+                params_to_clip = (
+                    itertools.chain(unet.parameters(), text_encoder.parameters())
+                    if args.train_text_encoder
+                    else unet.parameters()
+                )
                 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -970,7 +1020,7 @@ def main(args):
             # accelerator.save_state(save_path)
             # logger.info(f"Saved state to {save_path}")
 
-            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "lr_1d": lr_scheduler.get_last_lr()[1]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
 
@@ -982,14 +1032,8 @@ def main(args):
             log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
     accelerator.wait_for_everyone()
-
-
-    save_path = args.output_dir
-    unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-    state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
-    save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
-    print(f"[*] Weights saved at {save_path}")
-
+    # put the last checkpoint to output-dir
+    save_weights(global_step, save_path=args.output_dir)
     if accelerator.is_main_process:
         if args.push_to_hub:
             save_model_card(
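Note: the `save_weights` helper above persists only the tensors whose names contain `delta`, which is what keeps SVDiff checkpoints small. A minimal sketch of that filter-and-save pattern with safetensors (the toy module and file name are placeholders, not the project's API):

import torch
import torch.nn as nn
from safetensors.torch import save_file, load_file

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))  # frozen base weight, not saved
        self.delta = nn.Parameter(torch.zeros(4))       # trainable spectral shift, saved

model = ToyLayer()

# Cache the keys once, then slice the full state dict on every save,
# the same trick save_weights uses with state_dict_keys / state_dict_keys_te.
delta_keys = [k for k in model.state_dict().keys() if "delta" in k]
state_dict = model.state_dict()
save_file({k: state_dict[k] for k in delta_keys}, "spectral_shifts.safetensors")

loaded = load_file("spectral_shifts.safetensors")
print(list(loaded.keys()))  # ['delta']

At load time the saved shifts are re-applied on top of the frozen base weights by `load_unet_for_svdiff` / `load_text_encoder_for_svdiff`, as shown in inference.py above.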
trainer.py CHANGED

@@ -124,6 +124,8 @@ class Trainer:
             --train_batch_size=1 \
             --gradient_accumulation_steps={gradient_accumulation} \
             --learning_rate={learning_rate} \
+            --learning_rate_1d=1e-6 \
+            --train_text_encoder \
             --lr_scheduler=constant \
             --lr_warmup_steps=0 \
             --max_train_steps={n_steps} \
|