diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bee8a64b79a99590d5303307144172cfe824fbf7 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..217dd15ba99720de8aaecb80ce61e32b008ed278 --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +from trainscripts.textsliders import lora \ No newline at end of file diff --git a/app.py b/app.py index 3703e2db0009fea1686d779101b431c47248e5e9..4af54e9d1a93575b7bae3b3a2821e05ea25f8a6f 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,264 @@ import gradio as gr +import torch +import os +from utils import call +from diffusers.pipelines import StableDiffusionXLPipeline +StableDiffusionXLPipeline.__call__ = call -def greet(name): - return "Hello " + name + "!!" +model_map = {'Age' : 'models/age.pt', + 'Chubby': 'models/chubby.pt', + 'Muscular': 'models/muscular.pt', + 'Wavy Eyebrows': 'models/eyebrows.pt', + 'Small Eyes': 'models/eyesize.pt', + 'Long Hair' : 'models/longhair.pt', + 'Curly Hair' : 'models/curlyhair.pt', + 'Smiling' : 'models/smiling.pt', + 'Pixar Style' : 'models/pixar_style.pt', + 'Sculpture Style': 'models/sculpture_style.pt', + 'Repair Images': 'models/repair_slider.pt', + 'Fix Hands': 'models/fix_hands.pt', + } -iface = gr.Interface(fn=greet, inputs="text", outputs="text") -iface.launch() +ORIGINAL_SPACE_ID = 'baulab/ConceptSliders' +SPACE_ID = os.getenv('SPACE_ID') + +SHARED_UI_WARNING = f'''## Attention - Training does not work in this shared UI. You can either duplicate and use it with a gpu with at least 40GB, or clone this repository to run on your own machine. +
Duplicate Space
+
''' + + +class Demo: + + def __init__(self) -> None: + + self.training = False + self.generating = False + self.device = 'cuda' + self.weight_dtype = torch.float16 + self.pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=self.weight_dtype) + + with gr.Blocks() as demo: + self.layout() + demo.queue(concurrency_count=5).launch() + + + def layout(self): + + with gr.Row(): + + if SPACE_ID == ORIGINAL_SPACE_ID: + + self.warning = gr.Markdown(SHARED_UI_WARNING) + + with gr.Row(): + + with gr.Tab("Test") as inference_column: + + with gr.Row(): + + self.explain_infr = gr.Markdown(interactive=False, + value='This is a demo of [Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models](https://sliders.baulab.info/). To try out a model that can control a particular concept, select a model and enter any prompt. For example, if you select the model "Surprised Look", you can generate images for the prompt "A picture of a person, realistic, 8k" and compare the slider effect to the image generated by the original model. We also provide several other pre-fine-tuned models, such as the "repair" sliders that fix flaws in SDXL-generated images (check out the "Pretrained Sliders" drop-down). You can also train and run your own custom sliders: see the "Train" tab for custom concept slider training.') + + with gr.Row(): + + with gr.Column(scale=1): + + self.prompt_input_infr = gr.Text( + placeholder="Enter prompt...", + label="Prompt", + info="Prompt to generate" + ) + + with gr.Row(): + + self.model_dropdown = gr.Dropdown( + label="Pretrained Sliders", + choices=list(model_map.keys()), + value='Age', + interactive=True + ) + + self.seed_infr = gr.Number( + label="Seed", + value=12345 + ) + + with gr.Column(scale=2): + + self.infr_button = gr.Button( + value="Generate", + interactive=True + ) + + with gr.Row(): + + self.image_new = gr.Image( + label="Slider", + interactive=False + ) + self.image_orig = gr.Image( + label="Original SD", + interactive=False + ) + + with gr.Tab("Train") as training_column: + + with gr.Row(): + + self.explain_train = gr.Markdown(interactive=False, + value='In this section you can train a concept slider for Stable Diffusion XL. Enter a target concept you wish to edit. Next, enter an enhance prompt for the attribute you wish to edit (for example, to control the age of a person, enter "person, old"). Then, type the suppress prompt for the attribute (for our example, enter "person, young") and press the "Train" button. With default settings, it takes about 15 minutes to train a slider; then you can try inference above or download the weights. 
Code and details are at [github link](https://github.com/rohitgandikota/sliders).') + + with gr.Row(): + + with gr.Column(scale=3): + + self.target_concept = gr.Text( + placeholder="Enter the target concept to edit...", + label="Prompt of the concept on which the edit is made", + info="Prompt corresponding to the concept to edit" + ) + + self.positive_prompt = gr.Text( + placeholder="Enter the enhance prompt for the edit...", + label="Prompt to enhance", + info="Prompt corresponding to the concept to enhance" + ) + + self.negative_prompt = gr.Text( + placeholder="Enter the suppress prompt for the edit...", + label="Prompt to suppress", + info="Prompt corresponding to the concept to suppress" + ) + + + self.rank = gr.Number( + value=4, + label="Rank of the Slider", + info='Slider rank to train' + ) + + self.iterations_input = gr.Number( + value=1000, + precision=0, + label="Iterations", + info='Iterations used to train' + ) + + self.lr_input = gr.Number( + value=2e-4, + label="Learning Rate", + info='Learning rate used to train' + ) + + with gr.Column(scale=1): + + self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False) + + self.train_button = gr.Button( + value="Train", + ) + + self.download = gr.Files() + + self.infr_button.click(self.inference, inputs = [ + self.prompt_input_infr, + self.seed_infr, + self.model_dropdown + ], + outputs=[ + self.image_new, + self.image_orig + ] + ) + self.train_button.click(self.train, inputs = [ + self.target_concept, + self.positive_prompt, + self.negative_prompt, + self.rank, + self.iterations_input, + self.lr_input + ], + outputs=[self.train_button, self.train_status, self.download, self.model_dropdown] + ) + + def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)): + + if self.training: + return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()] + + if train_method == 'ESD-x': + + modules = ".*attn2$" + frozen = [] + + elif train_method == 'ESD-u': + + modules = "unet$" + frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"] + + elif train_method == 'ESD-self': + + modules = ".*attn1$" + frozen = [] + + randn = torch.randint(1, 10000000, (1,)).item() + + save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt" + + self.training = True + + train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path) + + self.training = False + + torch.cuda.empty_cache() + + model_map['Custom'] = save_path + + return [gr.update(interactive=True, value='Done Training! 
\n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')] + + + def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)): + + seed = seed or 12345 + + generator = torch.manual_seed(seed) + + model_path = model_map[model_name] + + checkpoint = torch.load(model_path) + + finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half() + + torch.cuda.empty_cache() + + images = self.diffuser( + prompt, + n_steps=50, + generator=generator + ) + + + orig_image = images[0][0] + + torch.cuda.empty_cache() + + generator = torch.manual_seed(seed) + + with finetuner: + + images = self.diffuser( + prompt, + n_steps=50, + generator=generator + ) + + edited_image = images[0][0] + + del finetuner + torch.cuda.empty_cache() + + return edited_image, orig_image + + +demo = Demo() diff --git a/models/age.pt b/models/age.pt new file mode 100644 index 0000000000000000000000000000000000000000..73e6f1df5932d6cf4642771f93abe1aabc571c9c --- /dev/null +++ b/models/age.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c1c096f7cc1109b4072cbc604c811a5f0ff034fc0f6dc7cf66a558550aa4890 +size 9142347 diff --git a/models/cartoon_style.pt b/models/cartoon_style.pt new file mode 100644 index 0000000000000000000000000000000000000000..30693de4bcb5b568acaa6b4cac9e8a92a70b12f1 --- /dev/null +++ b/models/cartoon_style.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e07c30e4f82f709a474ae11dc5108ac48f81b6996b937757c8dd198920ea9b4d +size 9146507 diff --git a/models/chubby.pt b/models/chubby.pt new file mode 100644 index 0000000000000000000000000000000000000000..aa53e4b9514c06d59cdcb030db58e8f49f10a223 --- /dev/null +++ b/models/chubby.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a70fb34187821a06a39bf36baa400090a32758d56771c3f54fcc4d9089f0d88 +size 9144427 diff --git a/models/clay_style.pt b/models/clay_style.pt new file mode 100644 index 0000000000000000000000000000000000000000..0d6e2254092e21c747dcd8f22b46f0d7fc7b5b47 --- /dev/null +++ b/models/clay_style.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b0deeb787248811fb8e54498768e303cffaeb3125d00c5fd303294af59a9380 +size 9143387 diff --git a/models/cluttered_room.pt b/models/cluttered_room.pt new file mode 100644 index 0000000000000000000000000000000000000000..e0d256bad32422471e8ed837927c3641795f2ee4 --- /dev/null +++ b/models/cluttered_room.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee409a45bfaa7ca01fbffe63ec185c0f5ccf0e7b0fa67070a9e0b41886b7ea66 +size 9140267 diff --git a/models/curlyhair.pt b/models/curlyhair.pt new file mode 100644 index 0000000000000000000000000000000000000000..8027862c385c3102fe47ac8a7323dc560fc95ffe --- /dev/null +++ b/models/curlyhair.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9b8d7d44da256291e3710f74954d352160ade5cbe291bce16c8f4951db31e7b +size 9136043 diff --git a/models/dark_weather.pt b/models/dark_weather.pt new file mode 100644 index 0000000000000000000000000000000000000000..e6c9ea0ed14369bda7a8811ce1cfe696bebf0cf3 --- /dev/null +++ b/models/dark_weather.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eecd2ae8b35022cbfb9c32637d9fa8c3c0ca3aa5ea189369c027f938064ada3c +size 9135003 diff --git a/models/eyebrow.pt b/models/eyebrow.pt new file mode 100644 index 0000000000000000000000000000000000000000..a369d41847caa9e764cbb06808a4f1b5bf6019d8 --- 
/dev/null +++ b/models/eyebrow.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:442770d2c30de92e30a1c2fcf9aab6b6cf5a3786eff84d513b7455345c79b57d +size 9135003 diff --git a/models/eyesize.pt b/models/eyesize.pt new file mode 100644 index 0000000000000000000000000000000000000000..f391212a5afc6cd2ec7f656f047e4a27053fcca4 --- /dev/null +++ b/models/eyesize.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fdffa3e7788f4bd6be9a2fe3b91957b4f35999fc9fa19eabfb49f92fbf6650b +size 9139227 diff --git a/models/festive.pt b/models/festive.pt new file mode 100644 index 0000000000000000000000000000000000000000..92c0c119ac2caa7ef6a51e3ffaf9417c2454c70b --- /dev/null +++ b/models/festive.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70d6c5d5be5f001510988852c2d233a916d23766675d9a000c6f785cd7e9127c +size 9133963 diff --git a/models/fix_hands.pt b/models/fix_hands.pt new file mode 100644 index 0000000000000000000000000000000000000000..0f25fb85e10b64ad6b07e884d35ae3cc1dcefb12 --- /dev/null +++ b/models/fix_hands.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d98c4828468c8d5831c439f49914672710f63219a561b191670fa54d542fa57b +size 9131883 diff --git a/models/long_hair.pt b/models/long_hair.pt new file mode 100644 index 0000000000000000000000000000000000000000..bd29a84f382cf21549a7820b0f1f222f385ffc8d --- /dev/null +++ b/models/long_hair.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e93dba27fa012bba0ea468eb2f9877ec0934424a9474e30ac9e94ea0517822ca +size 9147547 diff --git a/models/muscular.pt b/models/muscular.pt new file mode 100644 index 0000000000000000000000000000000000000000..3ba5ef81e74771d7ef92c94d680267df747d0e81 --- /dev/null +++ b/models/muscular.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b46b8eeac992f2d0e76ce887ea45ec1ce70bfbae053876de26d1f33f986eb37 +size 9135003 diff --git a/models/pixar_style.pt b/models/pixar_style.pt new file mode 100644 index 0000000000000000000000000000000000000000..30693de4bcb5b568acaa6b4cac9e8a92a70b12f1 --- /dev/null +++ b/models/pixar_style.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e07c30e4f82f709a474ae11dc5108ac48f81b6996b937757c8dd198920ea9b4d +size 9146507 diff --git a/models/professional.pt b/models/professional.pt new file mode 100644 index 0000000000000000000000000000000000000000..bf6b20c64205439cc7b2ad465765e96c5ebe9ca7 --- /dev/null +++ b/models/professional.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d4289f4c60dd008fe487369ddccf3492bd678cc1e6b30de2c17f9ce802b12ac +size 9151707 diff --git a/models/repair_slider.pt b/models/repair_slider.pt new file mode 100644 index 0000000000000000000000000000000000000000..bb6f2457cb434b455bdbb9001dffc6ad8d3aea95 --- /dev/null +++ b/models/repair_slider.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6e589e7d3b2174bb1d5d861a7218c4c26a94425b6dcdce0085b57f87ab841c5 +size 9133963 diff --git a/models/sculpture_style.pt b/models/sculpture_style.pt new file mode 100644 index 0000000000000000000000000000000000000000..1a0e83378c3aae9c73b1b23f762c2477d960be24 --- /dev/null +++ b/models/sculpture_style.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2779746c08062ccb128fdaa6cb66f061070ac8f19386701a99fb9291392d5985 +size 9148587 diff --git a/models/smiling.pt b/models/smiling.pt new file mode 100644 index 
0000000000000000000000000000000000000000..d89e16385e268a360cef81e980a5cda6d6904696 --- /dev/null +++ b/models/smiling.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6430ab47393ba15222ea0988c3479f547c8b59f93a41024bcddd7121ef7147d1 +size 9146507 diff --git a/models/stylegan_latent1.pt b/models/stylegan_latent1.pt new file mode 100644 index 0000000000000000000000000000000000000000..12296125e02088c1d5c81979dfd5a1df42faafdb --- /dev/null +++ b/models/stylegan_latent1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dca6cda8028af4587968cfed07c3bc6a2e79e5f8d01dad9351877f9de28f232d +size 9142347 diff --git a/models/stylegan_latent2.pt b/models/stylegan_latent2.pt new file mode 100644 index 0000000000000000000000000000000000000000..baf1573d937077e7d28d0979a7893b5f40d46bf3 --- /dev/null +++ b/models/stylegan_latent2.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bbe239c399a4fc7b73a034b643c406106cd4c8392ad806ee3fd8dd8c80ba5fc +size 9142347 diff --git a/models/suprised_look.pt b/models/suprised_look.pt new file mode 100644 index 0000000000000000000000000000000000000000..17be2279f20941039a7c8fd9737a46a1966f3c3d --- /dev/null +++ b/models/suprised_look.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36806271ca61dced2a506430c6c0b53ace09c68f65a90e09778c2bb5bcad31d4 +size 9148587 diff --git a/models/tropical_weather.pt b/models/tropical_weather.pt new file mode 100644 index 0000000000000000000000000000000000000000..ac54135d25bda0c86c4962a0c59c4abe66dfe427 --- /dev/null +++ b/models/tropical_weather.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:215e5445bbb7288ebea2e523181ca6db991417deca2736de29f0c3a76eb69ac0 +size 9135003 diff --git a/models/winter_weather.pt b/models/winter_weather.pt new file mode 100644 index 0000000000000000000000000000000000000000..ea3d0346e077462ac86afe59740e4b537a536a5a --- /dev/null +++ b/models/winter_weather.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38f0bc81bc3cdef0c1c47895df6c9f0a9b98507f48928ef971f341e02c76bb4c +size 9132923 diff --git a/requirements.txt b/reqs.txt similarity index 100% rename from requirements.txt rename to reqs.txt diff --git a/trainscripts/__init__.py b/trainscripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..587b6a2c2889cf9e7a3b39da81572b239000b45c --- /dev/null +++ b/trainscripts/__init__.py @@ -0,0 +1 @@ +# from textsliders import lora \ No newline at end of file diff --git a/trainscripts/imagesliders/config_util.py b/trainscripts/imagesliders/config_util.py new file mode 100644 index 0000000000000000000000000000000000000000..25e184821e6db86b265a6671a16b72d91682c205 --- /dev/null +++ b/trainscripts/imagesliders/config_util.py @@ -0,0 +1,104 @@ +from typing import Literal, Optional + +import yaml + +from pydantic import BaseModel +import torch + +from lora import TRAINING_METHODS + +PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"] +NETWORK_TYPES = Literal["lierla", "c3lier"] + + +class PretrainedModelConfig(BaseModel): + name_or_path: str + v2: bool = False + v_pred: bool = False + + clip_skip: Optional[int] = None + + +class NetworkConfig(BaseModel): + type: NETWORK_TYPES = "lierla" + rank: int = 4 + alpha: float = 1.0 + + training_method: TRAINING_METHODS = "full" + + +class TrainConfig(BaseModel): + precision: PRECISION_TYPES = "bfloat16" + noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim" + + 
iterations: int = 500 + lr: float = 1e-4 + optimizer: str = "adamw" + optimizer_args: str = "" + lr_scheduler: str = "constant" + + max_denoising_steps: int = 50 + + +class SaveConfig(BaseModel): + name: str = "untitled" + path: str = "./output" + per_steps: int = 200 + precision: PRECISION_TYPES = "float32" + + +class LoggingConfig(BaseModel): + use_wandb: bool = False + + verbose: bool = False + + +class OtherConfig(BaseModel): + use_xformers: bool = False + + +class RootConfig(BaseModel): + prompts_file: str + pretrained_model: PretrainedModelConfig + + network: NetworkConfig + + train: Optional[TrainConfig] + + save: Optional[SaveConfig] + + logging: Optional[LoggingConfig] + + other: Optional[OtherConfig] + + +def parse_precision(precision: str) -> torch.dtype: + if precision == "fp32" or precision == "float32": + return torch.float32 + elif precision == "fp16" or precision == "float16": + return torch.float16 + elif precision == "bf16" or precision == "bfloat16": + return torch.bfloat16 + + raise ValueError(f"Invalid precision type: {precision}") + + +def load_config_from_yaml(config_path: str) -> RootConfig: + with open(config_path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + root = RootConfig(**config) + + if root.train is None: + root.train = TrainConfig() + + if root.save is None: + root.save = SaveConfig() + + if root.logging is None: + root.logging = LoggingConfig() + + if root.other is None: + root.other = OtherConfig() + + return root diff --git a/trainscripts/imagesliders/data/config-xl.yaml b/trainscripts/imagesliders/data/config-xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15d4fd7f3ad65fd06406d62db540cbc35a30b11f --- /dev/null +++ b/trainscripts/imagesliders/data/config-xl.yaml @@ -0,0 +1,28 @@ +prompts_file: "trainscripts/imagesliders/data/prompts-xl.yaml" +pretrained_model: + name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models + v2: false # true if model is v2.x + v_pred: false # true if model uses v-prediction +network: + type: "c3lier" # or "c3lier" or "lierla" + rank: 4 + alpha: 1.0 + training_method: "noxattn" +train: + precision: "bfloat16" + noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" + iterations: 1000 + lr: 0.0002 + optimizer: "AdamW" + lr_scheduler: "constant" + max_denoising_steps: 50 +save: + name: "temp" + path: "./models" + per_steps: 500 + precision: "bfloat16" +logging: + use_wandb: false + verbose: false +other: + use_xformers: true \ No newline at end of file diff --git a/trainscripts/imagesliders/data/config.yaml b/trainscripts/imagesliders/data/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48ced6c3e996e953e1653708c07dbdc922a1e7db --- /dev/null +++ b/trainscripts/imagesliders/data/config.yaml @@ -0,0 +1,28 @@ +prompts_file: "trainscripts/imagesliders/data/prompts.yaml" +pretrained_model: + name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models + v2: false # true if model is v2.x + v_pred: false # true if model uses v-prediction +network: + type: "c3lier" # or "c3lier" or "lierla" + rank: 4 + alpha: 1.0 + training_method: "noxattn" +train: + precision: "bfloat16" + noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" + iterations: 1000 + lr: 0.0002 + optimizer: "AdamW" + lr_scheduler: "constant" + max_denoising_steps: 50 +save: + name: "temp" + path: "./models" + per_steps: 500 + precision: "bfloat16" +logging: + use_wandb: false + verbose: false +other: + 
use_xformers: true \ No newline at end of file diff --git a/trainscripts/imagesliders/data/prompts-xl.yaml b/trainscripts/imagesliders/data/prompts-xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b0bd048762021dc2374633cb556698e727426a8 --- /dev/null +++ b/trainscripts/imagesliders/data/prompts-xl.yaml @@ -0,0 +1,275 @@ +####################################################################################################### AGE SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, very old" # concept to erase +# unconditional: "male person, very young" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, very old" # concept to erase +# unconditional: "female person, very young" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### GLASSES SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, wearing glasses" # concept to erase +# unconditional: "male person" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, wearing glasses" # concept to erase +# unconditional: "female person" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### ASTRONAUGHT SLIDER +# - target: "astronaught" # what word for erasing the positive concept from +# positive: "astronaught, with orange colored spacesuit" # concept to erase +# unconditional: "astronaught" # word to take the difference from the positive concept +# neutral: "astronaught" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### SMILING SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, smiling" # concept to erase +# unconditional: "male person, frowning" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, smiling" # concept to 
erase +# unconditional: "female person, frowning" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### CAR COLOR SLIDER +# - target: "car" # what word for erasing the positive concept from +# positive: "car, white color" # concept to erase +# unconditional: "car, black color" # word to take the difference from the positive concept +# neutral: "car" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### DETAILS SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase +# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### BOKEH SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "blurred background, narrow DOF, bokeh effect" # concept to erase +# # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept +# unconditional: "" +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### LONG HAIR SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with long hair" # concept to erase +# unconditional: "male person, with short hair" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with long hair" # concept to erase +# unconditional: "female person, with short hair" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### IMAGE SLIDER +- target: "" # what word for erasing the positive concept from + positive: "" # concept to erase + unconditional: "" # word to take the difference from the positive concept + neutral: "" # starting point for conditioning the target + action: "enhance" # erase or enhance + 
guidance_scale: 4 + resolution: 512 + dynamic_resolution: false + batch_size: 1 +####################################################################################################### IMAGE SLIDER +# - target: "food" # what word for erasing the positive concept from +# positive: "food, expensive and fine dining" # concept to erase +# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept +# neutral: "food" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "room" # what word for erasing the positive concept from +# positive: "room, dirty disorganised and cluttered" # concept to erase +# unconditional: "room, neat organised and clean" # word to take the difference from the positive concept +# neutral: "room" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with a surprised look" # concept to erase +# unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with a surprised look" # concept to erase +# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "sky" # what word for erasing the positive concept from +# positive: "peaceful sky" # concept to erase +# unconditional: "sky" # word to take the difference from the positive concept +# neutral: "sky" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "sky" # what word for erasing the positive concept from +# positive: "chaotic dark sky" # concept to erase +# unconditional: "sky" # word to take the difference from the positive concept +# neutral: "sky" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "person" # what word for erasing the positive concept from +# positive: "person, very young" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# overweight +# - target: "art" # what word for erasing the positive concept from +# positive: "realistic art" # concept to erase +# unconditional: "art" # word to take the difference from the positive concept +# neutral: "art" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# 
batch_size: 1 +# - target: "art" # what word for erasing the positive concept from +# positive: "abstract art" # concept to erase +# unconditional: "art" # word to take the difference from the positive concept +# neutral: "art" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# sky +# - target: "weather" # what word for erasing the positive concept from +# positive: "bright pleasant weather" # concept to erase +# unconditional: "weather" # word to take the difference from the positive concept +# neutral: "weather" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "weather" # what word for erasing the positive concept from +# positive: "dark gloomy weather" # concept to erase +# unconditional: "weather" # word to take the difference from the positive concept +# neutral: "weather" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# hair +# - target: "person" # what word for erasing the positive concept from +# positive: "person with long hair" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "person" # what word for erasing the positive concept from +# positive: "person with short hair" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "girl" # what word for erasing the positive concept from +# positive: "baby girl" # concept to erase +# unconditional: "girl" # word to take the difference from the positive concept +# neutral: "girl" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: -4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "boy" # what word for erasing the positive concept from +# positive: "old man" # concept to erase +# unconditional: "boy" # word to take the difference from the positive concept +# neutral: "boy" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "boy" # what word for erasing the positive concept from +# positive: "baby boy" # concept to erase +# unconditional: "boy" # word to take the difference from the positive concept +# neutral: "boy" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: -4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 \ No newline at end of file diff --git a/trainscripts/imagesliders/data/prompts.yaml b/trainscripts/imagesliders/data/prompts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..09e26c5d0482a1d691da14ba6e812e680366a011 --- /dev/null +++ b/trainscripts/imagesliders/data/prompts.yaml @@ -0,0 +1,174 @@ +# - target: "person" # what word for erasing the 
positive concept from +# positive: "person, very old" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +- target: "" # what word for erasing the positive concept from + positive: "" # concept to erase + unconditional: "" # word to take the difference from the positive concept + neutral: "" # starting point for conditioning the target + action: "enhance" # erase or enhance + guidance_scale: 1 + resolution: 512 + dynamic_resolution: false + batch_size: 1 +# - target: "" # what word for erasing the positive concept from +# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase +# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "food" # what word for erasing the positive concept from +# positive: "food, expensive and fine dining" # concept to erase +# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept +# neutral: "food" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "room" # what word for erasing the positive concept from +# positive: "room, dirty disorganised and cluttered" # concept to erase +# unconditional: "room, neat organised and clean" # word to take the difference from the positive concept +# neutral: "room" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with a surprised look" # concept to erase +# unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with a surprised look" # concept to erase +# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "sky" # what word for erasing the positive concept from +# positive: "peaceful sky" # concept to erase +# unconditional: "sky" # word to take the difference from the positive concept +# neutral: "sky" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "sky" # what word for erasing the positive concept from +# positive: "chaotic dark sky" # concept to erase +# unconditional: "sky" # word to take 
the difference from the positive concept +# neutral: "sky" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "person" # what word for erasing the positive concept from +# positive: "person, very young" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# overweight +# - target: "art" # what word for erasing the positive concept from +# positive: "realistic art" # concept to erase +# unconditional: "art" # word to take the difference from the positive concept +# neutral: "art" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "art" # what word for erasing the positive concept from +# positive: "abstract art" # concept to erase +# unconditional: "art" # word to take the difference from the positive concept +# neutral: "art" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# sky +# - target: "weather" # what word for erasing the positive concept from +# positive: "bright pleasant weather" # concept to erase +# unconditional: "weather" # word to take the difference from the positive concept +# neutral: "weather" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "weather" # what word for erasing the positive concept from +# positive: "dark gloomy weather" # concept to erase +# unconditional: "weather" # word to take the difference from the positive concept +# neutral: "weather" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# hair +# - target: "person" # what word for erasing the positive concept from +# positive: "person with long hair" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "person" # what word for erasing the positive concept from +# positive: "person with short hair" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "girl" # what word for erasing the positive concept from +# positive: "baby girl" # concept to erase +# unconditional: "girl" # word to take the difference from the positive concept +# neutral: "girl" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: -4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "boy" # what word for erasing the positive concept from +# positive: "old man" # concept to erase +# unconditional: 
"boy" # word to take the difference from the positive concept +# neutral: "boy" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "boy" # what word for erasing the positive concept from +# positive: "baby boy" # concept to erase +# unconditional: "boy" # word to take the difference from the positive concept +# neutral: "boy" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: -4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 \ No newline at end of file diff --git a/trainscripts/imagesliders/debug_util.py b/trainscripts/imagesliders/debug_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5b115093b0ad6625db2f9aedaaa90924cf7de468 --- /dev/null +++ b/trainscripts/imagesliders/debug_util.py @@ -0,0 +1,16 @@ +# デバッグ用... + +import torch + + +def check_requires_grad(model: torch.nn.Module): + for name, module in list(model.named_modules())[:5]: + if len(list(module.parameters())) > 0: + print(f"Module: {name}") + for name, param in list(module.named_parameters())[:2]: + print(f" Parameter: {name}, Requires Grad: {param.requires_grad}") + + +def check_training_mode(model: torch.nn.Module): + for name, module in list(model.named_modules())[:5]: + print(f"Module: {name}, Training Mode: {module.training}") diff --git a/trainscripts/imagesliders/lora.py b/trainscripts/imagesliders/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..1762bb8ba32ba7b1c8cd84dc63963fa0704abf98 --- /dev/null +++ b/trainscripts/imagesliders/lora.py @@ -0,0 +1,256 @@ +# ref: +# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py + +import os +import math +from typing import Optional, List, Type, Set, Literal + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from safetensors.torch import save_file + + +UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ +# "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2 + "Attention" +] +UNET_TARGET_REPLACE_MODULE_CONV = [ + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + # "DownBlock2D", + # "UpBlock2D" +] # locon, 3clier + +LORA_PREFIX_UNET = "lora_unet" + +DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER + +TRAINING_METHODS = Literal[ + "noxattn", # train all layers except x-attns and time_embed layers + "innoxattn", # train all layers except self attention layers + "selfattn", # ESD-u, train only self attention layers + "xattn", # ESD-x, train only x attention layers + "full", # train all layers + "xattn-strict", # q and k values + "noxattn-hspace", + "noxattn-hspace-last", + # "xlayer", + # "outxattn", + # "outsattn", + # "inxattn", + # "inmidsattn", + # "selflayer", +] + + +class LoRAModule(nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. 
+ """ + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + if "Linear" in org_module.__class__.__name__: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + + elif "Conv" in org_module.__class__.__name__: # 一応 + in_dim = org_module.in_channels + out_dim = org_module.out_channels + + self.lora_dim = min(self.lora_dim, in_dim, out_dim) + if self.lora_dim != lora_dim: + print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = nn.Conv2d( + in_dim, self.lora_dim, kernel_size, stride, padding, bias=False + ) + self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + return ( + self.org_forward(x) + + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + ) + + +class LoRANetwork(nn.Module): + def __init__( + self, + unet: UNet2DConditionModel, + rank: int = 4, + multiplier: float = 1.0, + alpha: float = 1.0, + train_method: TRAINING_METHODS = "full", + ) -> None: + super().__init__() + self.lora_scale = 1 + self.multiplier = multiplier + self.lora_dim = rank + self.alpha = alpha + + # LoRAのみ + self.module = LoRAModule + + # unetのloraを作る + self.unet_loras = self.create_modules( + LORA_PREFIX_UNET, + unet, + DEFAULT_TARGET_REPLACE, + self.lora_dim, + self.multiplier, + train_method=train_method, + ) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + # assertion 名前の被りがないか確認しているようだ + lora_names = set() + for lora in self.unet_loras: + assert ( + lora.lora_name not in lora_names + ), f"duplicated lora name: {lora.lora_name}. 
{lora_names}" + lora_names.add(lora.lora_name) + + # 適用する + for lora in self.unet_loras: + lora.apply_to() + self.add_module( + lora.lora_name, + lora, + ) + + del unet + + torch.cuda.empty_cache() + + def create_modules( + self, + prefix: str, + root_module: nn.Module, + target_replace_modules: List[str], + rank: int, + multiplier: float, + train_method: TRAINING_METHODS, + ) -> list: + loras = [] + names = [] + for name, module in root_module.named_modules(): + if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習 + if "attn2" in name or "time_embed" in name: + continue + elif train_method == "innoxattn": # Cross Attention 以外学習 + if "attn2" in name: + continue + elif train_method == "selfattn": # Self Attention のみ学習 + if "attn1" not in name: + continue + elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習 + if "attn2" not in name: + continue + elif train_method == "full": # 全部学習 + pass + else: + raise NotImplementedError( + f"train_method: {train_method} is not implemented." + ) + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]: + if train_method == 'xattn-strict': + if 'out' in child_name: + continue + if train_method == 'noxattn-hspace': + if 'mid_block' not in name: + continue + if train_method == 'noxattn-hspace-last': + if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name: + continue + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") +# print(f"{lora_name}") + lora = self.module( + lora_name, child_module, multiplier, rank, self.alpha + ) +# print(name, child_name) +# print(child_module.weight.shape) + loras.append(lora) + names.append(lora_name) +# print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}') + return loras + + def prepare_optimizer_params(self): + all_params = [] + + if self.unet_loras: # 実質これしかない + params = [] + [params.extend(lora.parameters()) for lora in self.unet_loras] + param_data = {"params": params} + all_params.append(param_data) + + return all_params + + def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + +# for key in list(state_dict.keys()): +# if not key.startswith("lora"): +# # lora以外除外 +# del state_dict[key] + + if os.path.splitext(file)[1] == ".safetensors": + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + def set_lora_slider(self, scale): + self.lora_scale = scale + + def __enter__(self): + for lora in self.unet_loras: + lora.multiplier = 1.0 * self.lora_scale + + def __exit__(self, exc_type, exc_value, tb): + for lora in self.unet_loras: + lora.multiplier = 0 diff --git a/trainscripts/imagesliders/model_util.py b/trainscripts/imagesliders/model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..44d9d74c1d68b1e3366a6fed672de9e121225041 --- /dev/null +++ b/trainscripts/imagesliders/model_util.py @@ -0,0 +1,283 @@ +from typing import Literal, Union, Optional + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection +from diffusers import ( + UNet2DConditionModel, + SchedulerMixin, + 
StableDiffusionPipeline, + StableDiffusionXLPipeline, + AutoencoderKL, +) +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + LMSDiscreteScheduler, + EulerAncestralDiscreteScheduler, +) + + +TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" +TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" + +AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"] + +SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] + +DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this + + +def load_diffusers_model( + pretrained_model_name_or_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + # VAE はいらない + + if v2: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V2_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + # default is clip skip 2 + num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + else: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V1_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + return tokenizer, text_encoder, unet, vae + + +def load_checkpoint_model( + checkpoint_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + pipe = StableDiffusionPipeline.from_ckpt( + checkpoint_path, + upcast_attention=True if v2 else False, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + tokenizer = pipe.tokenizer + text_encoder = pipe.text_encoder + vae = pipe.vae + if clip_skip is not None: + if v2: + text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) + else: + text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) + + del pipe + + return tokenizer, text_encoder, unet, vae + + +def load_models( + pretrained_model_name_or_path: str, + scheduler_name: AVAILABLE_SCHEDULERS, + v2: bool = False, + v_pred: bool = False, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + tokenizer, text_encoder, unet, vae = load_checkpoint_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + else: # diffusers + tokenizer, text_encoder, unet, vae = load_diffusers_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + + # VAE はいらない + + scheduler = create_noise_scheduler( + scheduler_name, + prediction_type="v_prediction" if v_pred else "epsilon", + ) + + return tokenizer, text_encoder, unet, 
scheduler, vae + + +def load_diffusers_model_xl( + pretrained_model_name_or_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet + + tokenizers = [ + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + pad_token_id=0, # same as open clip + ), + ] + + text_encoders = [ + CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTextModelWithProjection.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + ] + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + return tokenizers, text_encoders, unet, vae + + +def load_checkpoint_model_xl( + checkpoint_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + pipe = StableDiffusionXLPipeline.from_single_file( + checkpoint_path, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + if len(text_encoders) == 2: + text_encoders[1].pad_token_id = 0 + + del pipe + + return tokenizers, text_encoders, unet + + +def load_models_xl( + pretrained_model_name_or_path: str, + scheduler_name: AVAILABLE_SCHEDULERS, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[ + list[CLIPTokenizer], + list[SDXL_TEXT_ENCODER_TYPE], + UNet2DConditionModel, + SchedulerMixin, +]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + ( + tokenizers, + text_encoders, + unet, + ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype) + else: # diffusers + ( + tokenizers, + text_encoders, + unet, + vae + ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype) + + scheduler = create_noise_scheduler(scheduler_name) + + return tokenizers, text_encoders, unet, scheduler, vae + + +def create_noise_scheduler( + scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", +) -> SchedulerMixin: + # 正直、どれがいいのかわからない。元の実装だとDDIMとDDPMとLMSを選べたのだけど、どれがいいのかわからぬ。 + + name = scheduler_name.lower().replace(" ", "_") + if name == "ddim": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + prediction_type=prediction_type, # これでいいの? 
+ ) + elif name == "ddpm": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm + scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + prediction_type=prediction_type, + ) + elif name == "lms": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete + scheduler = LMSDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + prediction_type=prediction_type, + ) + elif name == "euler_a": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral + scheduler = EulerAncestralDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + prediction_type=prediction_type, + ) + else: + raise ValueError(f"Unknown scheduler name: {name}") + + return scheduler diff --git a/trainscripts/imagesliders/prompt_util.py b/trainscripts/imagesliders/prompt_util.py new file mode 100644 index 0000000000000000000000000000000000000000..27de3eaa42044eba09b0917e694b441fb72a4ef7 --- /dev/null +++ b/trainscripts/imagesliders/prompt_util.py @@ -0,0 +1,174 @@ +from typing import Literal, Optional, Union, List + +import yaml +from pathlib import Path + + +from pydantic import BaseModel, root_validator +import torch +import copy + +ACTION_TYPES = Literal[ + "erase", + "enhance", +] + + +# XL は二種類必要なので +class PromptEmbedsXL: + text_embeds: torch.FloatTensor + pooled_embeds: torch.FloatTensor + + def __init__(self, *args) -> None: + self.text_embeds = args[0] + self.pooled_embeds = args[1] + + +# SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL +PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL] + + +class PromptEmbedsCache: # 使いまわしたいので + prompts: dict[str, PROMPT_EMBEDDING] = {} + + def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class PromptSettings(BaseModel): # yaml のやつ + target: str + positive: str = None # if None, target will be used + unconditional: str = "" # default is "" + neutral: str = None # if None, unconditional will be used + action: ACTION_TYPES = "erase" # default is "erase" + guidance_scale: float = 1.0 # default is 1.0 + resolution: int = 512 # default is 512 + dynamic_resolution: bool = False # default is False + batch_size: int = 1 # default is 1 + dynamic_crops: bool = False # default is False. 
only used when model is XL + + @root_validator(pre=True) + def fill_prompts(cls, values): + keys = values.keys() + if "target" not in keys: + raise ValueError("target must be specified") + if "positive" not in keys: + values["positive"] = values["target"] + if "unconditional" not in keys: + values["unconditional"] = "" + if "neutral" not in keys: + values["neutral"] = values["unconditional"] + + return values + + +class PromptEmbedsPair: + target: PROMPT_EMBEDDING # not want to generate the concept + positive: PROMPT_EMBEDDING # generate the concept + unconditional: PROMPT_EMBEDDING # uncondition (default should be empty) + neutral: PROMPT_EMBEDDING # base condition (default should be empty) + + guidance_scale: float + resolution: int + dynamic_resolution: bool + batch_size: int + dynamic_crops: bool + + loss_fn: torch.nn.Module + action: ACTION_TYPES + + def __init__( + self, + loss_fn: torch.nn.Module, + target: PROMPT_EMBEDDING, + positive: PROMPT_EMBEDDING, + unconditional: PROMPT_EMBEDDING, + neutral: PROMPT_EMBEDDING, + settings: PromptSettings, + ) -> None: + self.loss_fn = loss_fn + self.target = target + self.positive = positive + self.unconditional = unconditional + self.neutral = neutral + + self.guidance_scale = settings.guidance_scale + self.resolution = settings.resolution + self.dynamic_resolution = settings.dynamic_resolution + self.batch_size = settings.batch_size + self.dynamic_crops = settings.dynamic_crops + self.action = settings.action + + def _erase( + self, + target_latents: torch.FloatTensor, # "van gogh" + positive_latents: torch.FloatTensor, # "van gogh" + unconditional_latents: torch.FloatTensor, # "" + neutral_latents: torch.FloatTensor, # "" + ) -> torch.FloatTensor: + """Target latents are going not to have the positive concept.""" + return self.loss_fn( + target_latents, + neutral_latents + - self.guidance_scale * (positive_latents - unconditional_latents) + ) + + + def _enhance( + self, + target_latents: torch.FloatTensor, # "van gogh" + positive_latents: torch.FloatTensor, # "van gogh" + unconditional_latents: torch.FloatTensor, # "" + neutral_latents: torch.FloatTensor, # "" + ): + """Target latents are going to have the positive concept.""" + return self.loss_fn( + target_latents, + neutral_latents + + self.guidance_scale * (positive_latents - unconditional_latents) + ) + + def loss( + self, + **kwargs, + ): + if self.action == "erase": + return self._erase(**kwargs) + + elif self.action == "enhance": + return self._enhance(**kwargs) + + else: + raise ValueError("action must be erase or enhance") + + +def load_prompts_from_yaml(path, attributes = []): + with open(path, "r") as f: + prompts = yaml.safe_load(f) + print(prompts) + if len(prompts) == 0: + raise ValueError("prompts file is empty") + if len(attributes)!=0: + newprompts = [] + for i in range(len(prompts)): + for att in attributes: + copy_ = copy.deepcopy(prompts[i]) + copy_['target'] = att + ' ' + copy_['target'] + copy_['positive'] = att + ' ' + copy_['positive'] + copy_['neutral'] = att + ' ' + copy_['neutral'] + copy_['unconditional'] = att + ' ' + copy_['unconditional'] + newprompts.append(copy_) + else: + newprompts = copy.deepcopy(prompts) + + print(newprompts) + print(len(prompts), len(newprompts)) + prompt_settings = [PromptSettings(**prompt) for prompt in newprompts] + + return prompt_settings diff --git a/trainscripts/imagesliders/train_lora-scale-xl.py b/trainscripts/imagesliders/train_lora-scale-xl.py new file mode 100644 index 
0000000000000000000000000000000000000000..07bed82d2698dae49ea9a5e2ccc174288a84d9ab --- /dev/null +++ b/trainscripts/imagesliders/train_lora-scale-xl.py @@ -0,0 +1,548 @@ +# ref: +# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566 +# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py + +from typing import List, Optional +import argparse +import ast +from pathlib import Path +import gc, os +import numpy as np + +import torch +from tqdm import tqdm +from PIL import Image + + + +import train_util +import random +import model_util +import prompt_util +from prompt_util import ( + PromptEmbedsCache, + PromptEmbedsPair, + PromptSettings, + PromptEmbedsXL, +) +import debug_util +import config_util +from config_util import RootConfig + +import wandb + +NUM_IMAGES_PER_PROMPT = 1 +from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +def train( + config: RootConfig, + prompts: list[PromptSettings], + device, + folder_main: str, + folders, + scales, + +): + scales = np.array(scales) + folders = np.array(folders) + scales_unique = list(scales) + + metadata = { + "prompts": ",".join([prompt.json() for prompt in prompts]), + "config": config.json(), + } + save_path = Path(config.save.path) + + modules = DEFAULT_TARGET_REPLACE + if config.network.type == "c3lier": + modules += UNET_TARGET_REPLACE_MODULE_CONV + + if config.logging.verbose: + print(metadata) + + if config.logging.use_wandb: + wandb.init(project=f"LECO_{config.save.name}", config=metadata) + + weight_dtype = config_util.parse_precision(config.train.precision) + save_weight_dtype = config_util.parse_precision(config.train.precision) + + ( + tokenizers, + text_encoders, + unet, + noise_scheduler, + vae + ) = model_util.load_models_xl( + config.pretrained_model.name_or_path, + scheduler_name=config.train.noise_scheduler, + ) + + for text_encoder in text_encoders: + text_encoder.to(device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + unet.to(device, dtype=weight_dtype) + if config.other.use_xformers: + unet.enable_xformers_memory_efficient_attention() + unet.requires_grad_(False) + unet.eval() + + vae.to(device) + vae.requires_grad_(False) + vae.eval() + + network = LoRANetwork( + unet, + rank=config.network.rank, + multiplier=1.0, + alpha=config.network.alpha, + train_method=config.network.training_method, + ).to(device, dtype=weight_dtype) + + optimizer_module = train_util.get_optimizer(config.train.optimizer) + #optimizer_args + optimizer_kwargs = {} + if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0: + for arg in config.train.optimizer_args.split(" "): + key, value = arg.split("=") + value = ast.literal_eval(value) + optimizer_kwargs[key] = value + + optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs) + lr_scheduler = train_util.get_lr_scheduler( + config.train.lr_scheduler, + optimizer, + max_iterations=config.train.iterations, + lr_min=config.train.lr / 100, + ) + criteria = torch.nn.MSELoss() + + print("Prompts") + for settings in prompts: + print(settings) + + # debug + debug_util.check_requires_grad(network) + debug_util.check_training_mode(network) + + cache = PromptEmbedsCache() + prompt_pairs: list[PromptEmbedsPair] = [] + + with torch.no_grad(): + for settings in prompts: + print(settings) + for prompt 
in [ + settings.target, + settings.positive, + settings.neutral, + settings.unconditional, + ]: + if cache[prompt] == None: + tex_embs, pool_embs = train_util.encode_prompts_xl( + tokenizers, + text_encoders, + [prompt], + num_images_per_prompt=NUM_IMAGES_PER_PROMPT, + ) + cache[prompt] = PromptEmbedsXL( + tex_embs, + pool_embs + ) + + prompt_pairs.append( + PromptEmbedsPair( + criteria, + cache[settings.target], + cache[settings.positive], + cache[settings.unconditional], + cache[settings.neutral], + settings, + ) + ) + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + del tokenizer, text_encoder + + flush() + + pbar = tqdm(range(config.train.iterations)) + + loss = None + + for i in pbar: + with torch.no_grad(): + noise_scheduler.set_timesteps( + config.train.max_denoising_steps, device=device + ) + + optimizer.zero_grad() + + prompt_pair: PromptEmbedsPair = prompt_pairs[ + torch.randint(0, len(prompt_pairs), (1,)).item() + ] + + # 1 ~ 49 からランダム + timesteps_to = torch.randint( + 1, config.train.max_denoising_steps, (1,) + ).item() + + height, width = prompt_pair.resolution, prompt_pair.resolution + if prompt_pair.dynamic_resolution: + height, width = train_util.get_random_resolution_in_bucket( + prompt_pair.resolution + ) + + if config.logging.verbose: + print("guidance_scale:", prompt_pair.guidance_scale) + print("resolution:", prompt_pair.resolution) + print("dynamic_resolution:", prompt_pair.dynamic_resolution) + if prompt_pair.dynamic_resolution: + print("bucketed resolution:", (height, width)) + print("batch_size:", prompt_pair.batch_size) + print("dynamic_crops:", prompt_pair.dynamic_crops) + + + + scale_to_look = abs(random.choice(list(scales_unique))) + folder1 = folders[scales==-scale_to_look][0] + folder2 = folders[scales==scale_to_look][0] + + ims = os.listdir(f'{folder_main}/{folder1}/') + ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_] + random_sampler = random.randint(0, len(ims)-1) + + img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((512,512)) + img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((512,512)) + + seed = random.randint(0,2*15) + + generator = torch.manual_seed(seed) + denoised_latents_low, low_noise = train_util.get_noisy_image( + img1, + vae, + generator, + unet, + noise_scheduler, + start_timesteps=0, + total_timesteps=timesteps_to) + denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype) + low_noise = low_noise.to(device, dtype=weight_dtype) + + generator = torch.manual_seed(seed) + denoised_latents_high, high_noise = train_util.get_noisy_image( + img2, + vae, + generator, + unet, + noise_scheduler, + start_timesteps=0, + total_timesteps=timesteps_to) + denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype) + high_noise = high_noise.to(device, dtype=weight_dtype) + noise_scheduler.set_timesteps(1000) + + add_time_ids = train_util.get_add_time_ids( + height, + width, + dynamic_crops=prompt_pair.dynamic_crops, + dtype=weight_dtype, + ).to(device, dtype=weight_dtype) + + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / config.train.max_denoising_steps) + ] + try: + # with network: の外では空のLoRAのみが有効になる + high_latents = train_util.predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents_high, + text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.text_embeds, + prompt_pair.positive.text_embeds, + prompt_pair.batch_size, + ), + 
add_text_embeddings=train_util.concat_embeddings(
+                    prompt_pair.unconditional.pooled_embeds,
+                    prompt_pair.positive.pooled_embeds,
+                    prompt_pair.batch_size,
+                ),
+                add_time_ids=train_util.concat_embeddings(
+                    add_time_ids, add_time_ids, prompt_pair.batch_size
+                ),
+                guidance_scale=1,
+            ).to(device, dtype=torch.float32)
+        except Exception:
+            flush()
+            print(f'Error Occurred!: {np.array(img1).shape} {np.array(img2).shape}')
+            continue
+        # outside of a `with network:` block only the empty LoRA is in effect
+
+        low_latents = train_util.predict_noise_xl(
+            unet,
+            noise_scheduler,
+            current_timestep,
+            denoised_latents_low,
+            text_embeddings=train_util.concat_embeddings(
+                prompt_pair.unconditional.text_embeds,
+                prompt_pair.neutral.text_embeds,
+                prompt_pair.batch_size,
+            ),
+            add_text_embeddings=train_util.concat_embeddings(
+                prompt_pair.unconditional.pooled_embeds,
+                prompt_pair.neutral.pooled_embeds,
+                prompt_pair.batch_size,
+            ),
+            add_time_ids=train_util.concat_embeddings(
+                add_time_ids, add_time_ids, prompt_pair.batch_size
+            ),
+            guidance_scale=1,
+        ).to(device, dtype=torch.float32)
+
+        if config.logging.verbose:
+            print("high_latents:", high_latents[0, 0, :5, :5])
+            print("low_latents:", low_latents[0, 0, :5, :5])
+
+        network.set_lora_slider(scale=scale_to_look)
+        with network:
+            target_latents_high = train_util.predict_noise_xl(
+                unet,
+                noise_scheduler,
+                current_timestep,
+                denoised_latents_high,
+                text_embeddings=train_util.concat_embeddings(
+                    prompt_pair.unconditional.text_embeds,
+                    prompt_pair.positive.text_embeds,
+                    prompt_pair.batch_size,
+                ),
+                add_text_embeddings=train_util.concat_embeddings(
+                    prompt_pair.unconditional.pooled_embeds,
+                    prompt_pair.positive.pooled_embeds,
+                    prompt_pair.batch_size,
+                ),
+                add_time_ids=train_util.concat_embeddings(
+                    add_time_ids, add_time_ids, prompt_pair.batch_size
+                ),
+                guidance_scale=1,
+            ).to(device, dtype=torch.float32)
+
+        high_latents.requires_grad = False
+        low_latents.requires_grad = False
+
+        loss_high = criteria(target_latents_high, high_noise.to(torch.float32))
+        pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
+        loss_high.backward()
+
+        # opposite
+        network.set_lora_slider(scale=-scale_to_look)
+        with network:
+            target_latents_low = train_util.predict_noise_xl(
+                unet,
+                noise_scheduler,
+                current_timestep,
+                denoised_latents_low,
+                text_embeddings=train_util.concat_embeddings(
+                    prompt_pair.unconditional.text_embeds,
+                    prompt_pair.neutral.text_embeds,
+                    prompt_pair.batch_size,
+                ),
+                add_text_embeddings=train_util.concat_embeddings(
+                    prompt_pair.unconditional.pooled_embeds,
+                    prompt_pair.neutral.pooled_embeds,
+                    prompt_pair.batch_size,
+                ),
+                add_time_ids=train_util.concat_embeddings(
+                    add_time_ids, add_time_ids, prompt_pair.batch_size
+                ),
+                guidance_scale=1,
+            ).to(device, dtype=torch.float32)
+
+        high_latents.requires_grad = False
+        low_latents.requires_grad = False
+
+        loss_low = criteria(target_latents_low, low_noise.to(torch.float32))
+        pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
+        loss_low.backward()
+
+        optimizer.step()
+        lr_scheduler.step()
+
+        del (
+            high_latents,
+            low_latents,
+            target_latents_low,
+            target_latents_high,
+        )
+        flush()
+
+        if (
+            i % config.save.per_steps == 0
+            and i != 0
+            and i != config.train.iterations - 1
+        ):
+            print("Saving...")
+            save_path.mkdir(parents=True, exist_ok=True)
+            network.save_weights(
+                save_path / f"{config.save.name}_{i}steps.pt",
+                dtype=save_weight_dtype,
+            )
+
+    print("Saving...")
+
save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_last.pt", + dtype=save_weight_dtype, + ) + + del ( + unet, + noise_scheduler, + loss, + optimizer, + network, + ) + + flush() + + print("Done.") + + +def main(args): + config_file = args.config_file + + config = config_util.load_config_from_yaml(config_file) + if args.name is not None: + config.save.name = args.name + attributes = [] + if args.attributes is not None: + attributes = args.attributes.split(',') + attributes = [a.strip() for a in attributes] + + config.network.alpha = args.alpha + config.network.rank = args.rank + config.save.name += f'_alpha{args.alpha}' + config.save.name += f'_rank{config.network.rank }' + config.save.name += f'_{config.network.training_method}' + config.save.path += f'/{config.save.name}' + + prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes) + + device = torch.device(f"cuda:{args.device}") + + folders = args.folders.split(',') + folders = [f.strip() for f in folders] + scales = args.scales.split(',') + scales = [f.strip() for f in scales] + scales = [int(s) for s in scales] + + print(folders, scales) + if len(scales) != len(folders): + raise Exception('the number of folders need to match the number of scales') + + if args.stylecheck is not None: + check = args.stylecheck.split('-') + + for i in range(int(check[0]), int(check[1])): + folder_main = args.folder_main+ f'{i}' + config.save.name = f'{os.path.basename(folder_main)}' + config.save.name += f'_alpha{args.alpha}' + config.save.name += f'_rank{config.network.rank }' + config.save.path = f'models/{config.save.name}' + train(config=config, prompts=prompts, device=device, folder_main = folder_main, folders = folders, scales = scales) + else: + train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_file", + required=True, + help="Config file for training.", + ) + # config_file 'data/config.yaml' + parser.add_argument( + "--alpha", + type=float, + required=True, + help="LoRA weight.", + ) + # --alpha 1.0 + parser.add_argument( + "--rank", + type=int, + required=False, + help="Rank of LoRA.", + default=4, + ) + # --rank 4 + parser.add_argument( + "--device", + type=int, + required=False, + default=0, + help="Device to train on.", + ) + # --device 0 + parser.add_argument( + "--name", + type=str, + required=False, + default=None, + help="Device to train on.", + ) + # --name 'eyesize_slider' + parser.add_argument( + "--attributes", + type=str, + required=False, + default=None, + help="attritbutes to disentangle (comma seperated string)", + ) + parser.add_argument( + "--folder_main", + type=str, + required=True, + help="The folder to check", + ) + + parser.add_argument( + "--stylecheck", + type=str, + required=False, + default = None, + help="The folder to check", + ) + + parser.add_argument( + "--folders", + type=str, + required=False, + default = 'verylow, low, high, veryhigh', + help="folders with different attribute-scaled images", + ) + parser.add_argument( + "--scales", + type=str, + required=False, + default = '-2, -1, 1, 2', + help="scales for different attribute-scaled images", + ) + + + args = parser.parse_args() + + main(args) diff --git a/trainscripts/imagesliders/train_lora-scale.py b/trainscripts/imagesliders/train_lora-scale.py new file mode 100644 index 
0000000000000000000000000000000000000000..62c3a722e5a3958436971b0ed0d4e16aaab02b06 --- /dev/null +++ b/trainscripts/imagesliders/train_lora-scale.py @@ -0,0 +1,501 @@ +# ref: +# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566 +# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py + +from typing import List, Optional +import argparse +import ast +from pathlib import Path +import gc + +import torch +from tqdm import tqdm +import os, glob + +from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV +import train_util +import model_util +import prompt_util +from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings +import debug_util +import config_util +from config_util import RootConfig +import random +import numpy as np +import wandb +from PIL import Image + +def flush(): + torch.cuda.empty_cache() + gc.collect() +def prev_step(model_output, timestep, scheduler, sample): + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t =scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 + pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output + prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction + return prev_sample + +def train( + config: RootConfig, + prompts: list[PromptSettings], + device: int, + folder_main: str, + folders, + scales, +): + scales = np.array(scales) + folders = np.array(folders) + scales_unique = list(scales) + + metadata = { + "prompts": ",".join([prompt.json() for prompt in prompts]), + "config": config.json(), + } + save_path = Path(config.save.path) + + modules = DEFAULT_TARGET_REPLACE + if config.network.type == "c3lier": + modules += UNET_TARGET_REPLACE_MODULE_CONV + + if config.logging.verbose: + print(metadata) + + if config.logging.use_wandb: + wandb.init(project=f"LECO_{config.save.name}", config=metadata) + + weight_dtype = config_util.parse_precision(config.train.precision) + save_weight_dtype = config_util.parse_precision(config.train.precision) + + tokenizer, text_encoder, unet, noise_scheduler, vae = model_util.load_models( + config.pretrained_model.name_or_path, + scheduler_name=config.train.noise_scheduler, + v2=config.pretrained_model.v2, + v_pred=config.pretrained_model.v_pred, + ) + + text_encoder.to(device, dtype=weight_dtype) + text_encoder.eval() + + unet.to(device, dtype=weight_dtype) + unet.enable_xformers_memory_efficient_attention() + unet.requires_grad_(False) + unet.eval() + + vae.to(device) + vae.requires_grad_(False) + vae.eval() + + network = LoRANetwork( + unet, + rank=config.network.rank, + multiplier=1.0, + alpha=config.network.alpha, + train_method=config.network.training_method, + ).to(device, dtype=weight_dtype) + + optimizer_module = train_util.get_optimizer(config.train.optimizer) + #optimizer_args + optimizer_kwargs = {} + if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0: + for arg in config.train.optimizer_args.split(" "): + key, value = arg.split("=") + value = ast.literal_eval(value) + optimizer_kwargs[key] = value + + optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, 
**optimizer_kwargs) + lr_scheduler = train_util.get_lr_scheduler( + config.train.lr_scheduler, + optimizer, + max_iterations=config.train.iterations, + lr_min=config.train.lr / 100, + ) + criteria = torch.nn.MSELoss() + + print("Prompts") + for settings in prompts: + print(settings) + + # debug + debug_util.check_requires_grad(network) + debug_util.check_training_mode(network) + + cache = PromptEmbedsCache() + prompt_pairs: list[PromptEmbedsPair] = [] + + with torch.no_grad(): + for settings in prompts: + print(settings) + for prompt in [ + settings.target, + settings.positive, + settings.neutral, + settings.unconditional, + ]: + print(prompt) + if isinstance(prompt, list): + if prompt == settings.positive: + key_setting = 'positive' + else: + key_setting = 'attributes' + if len(prompt) == 0: + cache[key_setting] = [] + else: + if cache[key_setting] is None: + cache[key_setting] = train_util.encode_prompts( + tokenizer, text_encoder, prompt + ) + else: + if cache[prompt] == None: + cache[prompt] = train_util.encode_prompts( + tokenizer, text_encoder, [prompt] + ) + + prompt_pairs.append( + PromptEmbedsPair( + criteria, + cache[settings.target], + cache[settings.positive], + cache[settings.unconditional], + cache[settings.neutral], + settings, + ) + ) + + del tokenizer + del text_encoder + + flush() + + pbar = tqdm(range(config.train.iterations)) + for i in pbar: + with torch.no_grad(): + noise_scheduler.set_timesteps( + config.train.max_denoising_steps, device=device + ) + + optimizer.zero_grad() + + prompt_pair: PromptEmbedsPair = prompt_pairs[ + torch.randint(0, len(prompt_pairs), (1,)).item() + ] + + # 1 ~ 49 からランダム + timesteps_to = torch.randint( + 1, config.train.max_denoising_steps-1, (1,) +# 1, 25, (1,) + ).item() + + height, width = ( + prompt_pair.resolution, + prompt_pair.resolution, + ) + if prompt_pair.dynamic_resolution: + height, width = train_util.get_random_resolution_in_bucket( + prompt_pair.resolution + ) + + if config.logging.verbose: + print("guidance_scale:", prompt_pair.guidance_scale) + print("resolution:", prompt_pair.resolution) + print("dynamic_resolution:", prompt_pair.dynamic_resolution) + if prompt_pair.dynamic_resolution: + print("bucketed resolution:", (height, width)) + print("batch_size:", prompt_pair.batch_size) + + + + + scale_to_look = abs(random.choice(list(scales_unique))) + folder1 = folders[scales==-scale_to_look][0] + folder2 = folders[scales==scale_to_look][0] + + ims = os.listdir(f'{folder_main}/{folder1}/') + ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_] + random_sampler = random.randint(0, len(ims)-1) + + img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((256,256)) + img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((256,256)) + + seed = random.randint(0,2*15) + + generator = torch.manual_seed(seed) + denoised_latents_low, low_noise = train_util.get_noisy_image( + img1, + vae, + generator, + unet, + noise_scheduler, + start_timesteps=0, + total_timesteps=timesteps_to) + denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype) + low_noise = low_noise.to(device, dtype=weight_dtype) + + generator = torch.manual_seed(seed) + denoised_latents_high, high_noise = train_util.get_noisy_image( + img2, + vae, + generator, + unet, + noise_scheduler, + start_timesteps=0, + total_timesteps=timesteps_to) + denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype) + high_noise = high_noise.to(device, dtype=weight_dtype) + 
noise_scheduler.set_timesteps(1000) + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / config.train.max_denoising_steps) + ] + + # with network: の外では空のLoRAのみが有効になる + high_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents_high, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.positive, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + # with network: の外では空のLoRAのみが有効になる + low_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents_low, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.unconditional, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + if config.logging.verbose: + print("positive_latents:", positive_latents[0, 0, :5, :5]) + print("neutral_latents:", neutral_latents[0, 0, :5, :5]) + print("unconditional_latents:", unconditional_latents[0, 0, :5, :5]) + + network.set_lora_slider(scale=scale_to_look) + with network: + target_latents_high = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents_high, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.positive, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + + + high_latents.requires_grad = False + low_latents.requires_grad = False + + loss_high = criteria(target_latents_high, high_noise.cpu().to(torch.float32)) + pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}") + loss_high.backward() + + + network.set_lora_slider(scale=-scale_to_look) + with network: + target_latents_low = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents_low, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.neutral, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to("cpu", dtype=torch.float32) + + + high_latents.requires_grad = False + low_latents.requires_grad = False + + loss_low = criteria(target_latents_low, low_noise.cpu().to(torch.float32)) + pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}") + loss_low.backward() + + ## NOTICE NO zero_grad between these steps (accumulating gradients) + #following guidelines from Ostris (https://github.com/ostris/ai-toolkit) + + optimizer.step() + lr_scheduler.step() + + del ( + high_latents, + low_latents, + target_latents_low, + target_latents_high, + ) + flush() + + if ( + i % config.save.per_steps == 0 + and i != 0 + and i != config.train.iterations - 1 + ): + print("Saving...") + save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_{i}steps.pt", + dtype=save_weight_dtype, + ) + + print("Saving...") + save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_last.pt", + dtype=save_weight_dtype, + ) + + del ( + unet, + noise_scheduler, + optimizer, + network, + ) + + flush() + + print("Done.") + + +def main(args): + config_file = args.config_file + + config = config_util.load_config_from_yaml(config_file) + if args.name is not None: + config.save.name = args.name + attributes = [] + if args.attributes is not None: + attributes = args.attributes.split(',') + attributes = [a.strip() for a in attributes] + + config.network.alpha = args.alpha + config.network.rank = args.rank + config.save.name += f'_alpha{args.alpha}' + config.save.name += f'_rank{config.network.rank }' + config.save.name 
+= f'_{config.network.training_method}'
+    config.save.path += f'/{config.save.name}'
+
+    prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
+    device = torch.device(f"cuda:{args.device}")
+
+    folders = args.folders.split(',')
+    folders = [f.strip() for f in folders]
+    scales = args.scales.split(',')
+    scales = [f.strip() for f in scales]
+    scales = [int(s) for s in scales]
+
+    print(folders, scales)
+    if len(scales) != len(folders):
+        raise Exception('the number of folders needs to match the number of scales')
+
+    if args.stylecheck is not None:
+        check = args.stylecheck.split('-')
+
+        for i in range(int(check[0]), int(check[1])):
+            folder_main = args.folder_main + f'{i}'
+            config.save.name = f'{os.path.basename(folder_main)}'
+            config.save.name += f'_alpha{args.alpha}'
+            config.save.name += f'_rank{config.network.rank}'
+            config.save.path = f'models/{config.save.name}'
+            train(config=config, prompts=prompts, device=device, folder_main=folder_main, folders=folders, scales=scales)
+    else:
+        train(config=config, prompts=prompts, device=device, folder_main=args.folder_main, folders=folders, scales=scales)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config_file",
+        required=False,
+        default='data/config.yaml',
+        help="Config file for training.",
+    )
+    parser.add_argument(
+        "--alpha",
+        type=float,
+        required=True,
+        help="LoRA weight.",
+    )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        required=False,
+        help="Rank of LoRA.",
+        default=4,
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        required=False,
+        default=0,
+        help="Device to train on.",
+    )
+    parser.add_argument(
+        "--name",
+        type=str,
+        required=False,
+        default=None,
+        help="Name of the trained slider (used for saving).",
+    )
+    parser.add_argument(
+        "--attributes",
+        type=str,
+        required=False,
+        default=None,
+        help="attributes to disentangle (comma separated string)",
+    )
+    parser.add_argument(
+        "--folder_main",
+        type=str,
+        required=True,
+        help="The main folder containing the attribute-scaled image subfolders.",
+    )
+    parser.add_argument(
+        "--stylecheck",
+        type=str,
+        required=False,
+        default=None,
+        help="Range of numbered style folders to sweep, e.g. '0-10' (appended to folder_main).",
+    )
+    parser.add_argument(
+        "--folders",
+        type=str,
+        required=False,
+        default='verylow, low, high, veryhigh',
+        help="folders with different attribute-scaled images",
+    )
+    parser.add_argument(
+        "--scales",
+        type=str,
+        required=False,
+        default='-2, -1, 1, 2',
+        help="scales for different attribute-scaled images",
+    )
+
+    args = parser.parse_args()
+
+    main(args)
diff --git a/trainscripts/imagesliders/train_util.py b/trainscripts/imagesliders/train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..36618d91eaae074388d2362604b5a9396f5656ee
--- /dev/null
+++ b/trainscripts/imagesliders/train_util.py
@@ -0,0 +1,458 @@
+from typing import Optional, Union
+
+import torch
+
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import UNet2DConditionModel, SchedulerMixin
+from diffusers.image_processor import VaeImageProcessor
+from model_util import SDXL_TEXT_ENCODER_TYPE
+from diffusers.utils import randn_tensor
+
+from tqdm import tqdm
+
+UNET_IN_CHANNELS = 4  # Stable Diffusion's UNet in_channels is fixed at 4; the same holds for XL.
+VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
+
+UNET_ATTENTION_TIME_EMBED_DIM = 256  # XL
+TEXT_ENCODER_2_PROJECTION_DIM = 1280
+UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
+
+
+def get_random_noise(
+    batch_size: int, height: int, width: int, generator: torch.Generator = None
+) -> torch.Tensor:
+    return torch.randn(
+        (
+            batch_size,
+            UNET_IN_CHANNELS,
+            height // VAE_SCALE_FACTOR,  # not sure the height/width order is correct here, but it makes no practical difference either way
+            width // VAE_SCALE_FACTOR,
+        ),
+        generator=generator,
+        device="cpu",
+    )
+
+
+# https://www.crosslabs.org/blog/diffusion-with-offset-noise
+def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
+    latents = latents + noise_offset * torch.randn(
+        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+    )
+    return latents
+
+
+def get_initial_latents(
+    scheduler: SchedulerMixin,
+    n_imgs: int,
+    height: int,
+    width: int,
+    n_prompts: int,
+    generator=None,
+) -> torch.Tensor:
+    noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
+        n_prompts, 1, 1, 1
+    )
+
+    latents = noise * scheduler.init_noise_sigma
+
+    return latents
+
+
+def text_tokenize(
+    tokenizer: CLIPTokenizer,  # one tokenizer normally, two for XL
+    prompts: list[str],
+):
+    return tokenizer(
+        prompts,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    ).input_ids
+
+
+def text_encode(text_encoder: CLIPTextModel, tokens):
+    return text_encoder(tokens.to(text_encoder.device))[0]
+
+
+def encode_prompts(
+    tokenizer: CLIPTokenizer,
+    text_encoder: CLIPTextModel,
+    prompts: list[str],
+):
+    text_tokens = text_tokenize(tokenizer, prompts)
+    text_embeddings = text_encode(text_encoder, text_tokens)
+
+    return text_embeddings
+
+
+# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
+def text_encode_xl(
+    text_encoder: SDXL_TEXT_ENCODER_TYPE,
+    tokens: torch.FloatTensor,
+    num_images_per_prompt: int = 1,
+):
+    prompt_embeds = text_encoder(
+        tokens.to(text_encoder.device), output_hidden_states=True
+    )
+    pooled_prompt_embeds = prompt_embeds[0]
+    prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
+
+    bs_embed, seq_len, _ = prompt_embeds.shape
+    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+    return prompt_embeds, pooled_prompt_embeds
+
+
+def encode_prompts_xl(
+    tokenizers: list[CLIPTokenizer],
+    text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
+    prompts: list[str],
+    num_images_per_prompt: int = 1,
+) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+    # text_encoder and text_encoder_2's penultimate layer outputs
+    text_embeds_list = []
+    pooled_text_embeds = None  # always text_encoder_2's pool
+
+    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+        text_tokens_input_ids = text_tokenize(tokenizer, prompts)
+        text_embeds, pooled_text_embeds = text_encode_xl(
+            text_encoder, text_tokens_input_ids, num_images_per_prompt
+        )
+
+        text_embeds_list.append(text_embeds)
+
+    bs_embed = pooled_text_embeds.shape[0]
+    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
+        bs_embed * num_images_per_prompt, -1
+    )
+
+    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
+
+
+def concat_embeddings(
+    unconditional: torch.FloatTensor,
+    conditional: torch.FloatTensor,
+    n_imgs: int,
+):
+    return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
+
+
+# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
+def predict_noise(
+    unet: UNet2DConditionModel,
+    scheduler: SchedulerMixin,
+    timestep: int,  # current timestep
+
latents: torch.FloatTensor, + text_embeddings: torch.FloatTensor, # uncond な text embed と cond な text embed を結合したもの + guidance_scale=7.5, +) -> torch.FloatTensor: + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guided_target = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return guided_target + + + +# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 +@torch.no_grad() +def diffusion( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + latents: torch.FloatTensor, # ただのノイズだけのlatents + text_embeddings: torch.FloatTensor, + total_timesteps: int = 1000, + start_timesteps=0, + **kwargs, +): + # latents_steps = [] + + for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]): + noise_pred = predict_noise( + unet, scheduler, timestep, latents, text_embeddings, **kwargs + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + + # return latents_steps + return latents + +@torch.no_grad() +def get_noisy_image( + img, + vae, + generator, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + total_timesteps: int = 1000, + start_timesteps=0, + + **kwargs, +): + # latents_steps = [] + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + image = img + im_orig = image + device = vae.device + image = image_processor.preprocess(image).to(device) + + init_latents = vae.encode(image).latent_dist.sample(None) + init_latents = vae.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + + noise = randn_tensor(shape, generator=generator, device=device) + + time_ = total_timesteps + timestep = scheduler.timesteps[time_:time_+1] + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + return init_latents, noise + + +def rescale_noise_cfg( + noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0 +): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + + return noise_cfg + + +def predict_noise_xl( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + timestep: int, # 現在のタイムステップ + latents: torch.FloatTensor, + text_embeddings: torch.FloatTensor, # uncond な text embed と cond な text embed を結合したもの + add_text_embeddings: torch.FloatTensor, # pooled なやつ + add_time_ids: torch.FloatTensor, + guidance_scale=7.5, + guidance_rescale=0.7, +) -> torch.FloatTensor: + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + added_cond_kwargs = { + "text_embeds": add_text_embeddings, + "time_ids": add_time_ids, + } + + # predict the noise residual + noise_pred = unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guided_target = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) + + return guided_target + + +@torch.no_grad() +def diffusion_xl( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + latents: torch.FloatTensor, # ただのノイズだけのlatents + text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor], + add_text_embeddings: torch.FloatTensor, # pooled なやつ + add_time_ids: torch.FloatTensor, + guidance_scale: float = 1.0, + total_timesteps: int = 1000, + start_timesteps=0, +): + # latents_steps = [] + + for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]): + noise_pred = predict_noise_xl( + unet, + scheduler, + timestep, + latents, + text_embeddings, + add_text_embeddings, + add_time_ids, + guidance_scale=guidance_scale, + guidance_rescale=0.7, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + + # return latents_steps + return latents + + +# for XL +def get_add_time_ids( + height: int, + width: int, + dynamic_crops: bool = False, + dtype: torch.dtype = torch.float32, +): + if dynamic_crops: + # random float scale between 1 and 3 + random_scale = torch.rand(1).item() * 2 + 1 + original_size = (int(height * random_scale), int(width * random_scale)) + # random position + crops_coords_top_left = ( + torch.randint(0, original_size[0] - height, (1,)).item(), + torch.randint(0, original_size[1] - width, (1,)).item(), + ) + target_size = (height, width) + else: + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + + # this is expected as 6 + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + # this is expected as 2816 + 
passed_add_embed_dim = ( + UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6 + + TEXT_ENCODER_2_PROJECTION_DIM # + 1280 + ) + if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM: + raise ValueError( + f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + +def get_optimizer(name: str): + name = name.lower() + + if name.startswith("dadapt"): + import dadaptation + + if name == "dadaptadam": + return dadaptation.DAdaptAdam + elif name == "dadaptlion": + return dadaptation.DAdaptLion + else: + raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion") + + elif name.endswith("8bit"): # 検証してない + import bitsandbytes as bnb + + if name == "adam8bit": + return bnb.optim.Adam8bit + elif name == "lion8bit": + return bnb.optim.Lion8bit + else: + raise ValueError("8bit optimizer must be adam8bit or lion8bit") + + else: + if name == "adam": + return torch.optim.Adam + elif name == "adamw": + return torch.optim.AdamW + elif name == "lion": + from lion_pytorch import Lion + + return Lion + elif name == "prodigy": + import prodigyopt + + return prodigyopt.Prodigy + else: + raise ValueError("Optimizer must be adam, adamw, lion or Prodigy") + + +def get_lr_scheduler( + name: Optional[str], + optimizer: torch.optim.Optimizer, + max_iterations: Optional[int], + lr_min: Optional[float], + **kwargs, +): + if name == "cosine": + return torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs + ) + elif name == "cosine_with_restarts": + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs + ) + elif name == "step": + return torch.optim.lr_scheduler.StepLR( + optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs + ) + elif name == "constant": + return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs) + elif name == "linear": + return torch.optim.lr_scheduler.LinearLR( + optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs + ) + else: + raise ValueError( + "Scheduler must be cosine, cosine_with_restarts, step, linear or constant" + ) + + +def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]: + max_resolution = bucket_resolution + min_resolution = bucket_resolution // 2 + + step = 64 + + min_step = min_resolution // step + max_step = max_resolution // step + + height = torch.randint(min_step, max_step, (1,)).item() * step + width = torch.randint(min_step, max_step, (1,)).item() * step + + return height, width diff --git a/trainscripts/textsliders/__init__.py b/trainscripts/textsliders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trainscripts/textsliders/config_util.py b/trainscripts/textsliders/config_util.py new file mode 100644 index 0000000000000000000000000000000000000000..25e184821e6db86b265a6671a16b72d91682c205 --- /dev/null +++ b/trainscripts/textsliders/config_util.py @@ -0,0 +1,104 @@ +from typing import Literal, Optional + +import yaml + +from pydantic import BaseModel +import torch + +from lora import TRAINING_METHODS + +PRECISION_TYPES = Literal["fp32", "fp16", 
"bf16", "float32", "float16", "bfloat16"] +NETWORK_TYPES = Literal["lierla", "c3lier"] + + +class PretrainedModelConfig(BaseModel): + name_or_path: str + v2: bool = False + v_pred: bool = False + + clip_skip: Optional[int] = None + + +class NetworkConfig(BaseModel): + type: NETWORK_TYPES = "lierla" + rank: int = 4 + alpha: float = 1.0 + + training_method: TRAINING_METHODS = "full" + + +class TrainConfig(BaseModel): + precision: PRECISION_TYPES = "bfloat16" + noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim" + + iterations: int = 500 + lr: float = 1e-4 + optimizer: str = "adamw" + optimizer_args: str = "" + lr_scheduler: str = "constant" + + max_denoising_steps: int = 50 + + +class SaveConfig(BaseModel): + name: str = "untitled" + path: str = "./output" + per_steps: int = 200 + precision: PRECISION_TYPES = "float32" + + +class LoggingConfig(BaseModel): + use_wandb: bool = False + + verbose: bool = False + + +class OtherConfig(BaseModel): + use_xformers: bool = False + + +class RootConfig(BaseModel): + prompts_file: str + pretrained_model: PretrainedModelConfig + + network: NetworkConfig + + train: Optional[TrainConfig] + + save: Optional[SaveConfig] + + logging: Optional[LoggingConfig] + + other: Optional[OtherConfig] + + +def parse_precision(precision: str) -> torch.dtype: + if precision == "fp32" or precision == "float32": + return torch.float32 + elif precision == "fp16" or precision == "float16": + return torch.float16 + elif precision == "bf16" or precision == "bfloat16": + return torch.bfloat16 + + raise ValueError(f"Invalid precision type: {precision}") + + +def load_config_from_yaml(config_path: str) -> RootConfig: + with open(config_path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + root = RootConfig(**config) + + if root.train is None: + root.train = TrainConfig() + + if root.save is None: + root.save = SaveConfig() + + if root.logging is None: + root.logging = LoggingConfig() + + if root.other is None: + root.other = OtherConfig() + + return root diff --git a/trainscripts/textsliders/data/config-xl.yaml b/trainscripts/textsliders/data/config-xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7cd655a0b214dcadf49d6c2391c1a92c758b8ccd --- /dev/null +++ b/trainscripts/textsliders/data/config-xl.yaml @@ -0,0 +1,28 @@ +prompts_file: "trainscripts/textsliders/data/prompts-xl.yaml" +pretrained_model: + name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models + v2: false # true if model is v2.x + v_pred: false # true if model uses v-prediction +network: + type: "c3lier" # or "c3lier" or "lierla" + rank: 4 + alpha: 1.0 + training_method: "noxattn" +train: + precision: "bfloat16" + noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" + iterations: 1000 + lr: 0.0002 + optimizer: "AdamW" + lr_scheduler: "constant" + max_denoising_steps: 50 +save: + name: "temp" + path: "./models" + per_steps: 500 + precision: "bfloat16" +logging: + use_wandb: false + verbose: false +other: + use_xformers: true \ No newline at end of file diff --git a/trainscripts/textsliders/data/config.yaml b/trainscripts/textsliders/data/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b1cbed713d24e74615e259a76ffc365ac396d7a --- /dev/null +++ b/trainscripts/textsliders/data/config.yaml @@ -0,0 +1,28 @@ +prompts_file: "trainscripts/textsliders/data/prompts.yaml" +pretrained_model: + name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models + v2: 
false # true if model is v2.x + v_pred: false # true if model uses v-prediction +network: + type: "c3lier" # or "c3lier" or "lierla" + rank: 4 + alpha: 1.0 + training_method: "noxattn" +train: + precision: "bfloat16" + noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" + iterations: 1000 + lr: 0.0002 + optimizer: "AdamW" + lr_scheduler: "constant" + max_denoising_steps: 50 +save: + name: "temp" + path: "./models" + per_steps: 500 + precision: "bfloat16" +logging: + use_wandb: false + verbose: false +other: + use_xformers: true \ No newline at end of file diff --git a/trainscripts/textsliders/data/prompts-xl.yaml b/trainscripts/textsliders/data/prompts-xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..093f57b68818b3bbf8aa5d81ea047e2983e36cf4 --- /dev/null +++ b/trainscripts/textsliders/data/prompts-xl.yaml @@ -0,0 +1,477 @@ +####################################################################################################### AGE SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, very old" # concept to erase +# unconditional: "male person, very young" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, very old" # concept to erase +# unconditional: "female person, very young" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### MUSCULAR SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, muscular, strong, biceps, greek god physique, body builder" # concept to erase +# unconditional: "male person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, muscular, strong, biceps, greek god physique, body builder" # concept to erase +# unconditional: "female person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### CURLY HAIR SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, curly hair, wavy hair" # concept to erase +# unconditional: "male person, straight hair" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false 
+# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, curly hair, wavy hair" # concept to erase +# unconditional: "female person, straight hair" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### BEARD SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with beard" # concept to erase +# unconditional: "male person, clean shaven" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with beard, lipstick and feminine" # concept to erase +# unconditional: "female person, clean shaven" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### MAKEUP SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with makeup, cosmetic, concealer, mascara" # concept to erase +# unconditional: "male person, barefaced, ugly" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with makeup, cosmetic, concealer, mascara, lipstick" # concept to erase +# unconditional: "female person, barefaced, ugly" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### SURPRISED SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with shocked look, surprised, stunned, amazed" # concept to erase +# unconditional: "male person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with shocked look, surprised, stunned, amazed" # concept to erase +# unconditional: "female person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# 
action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### OBESE SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, fat, chubby, overweight, obese" # concept to erase +# unconditional: "male person, lean, fit, slim, slender" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, fat, chubby, overweight, obese" # concept to erase +# unconditional: "female person, lean, fit, slim, slender" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### PROFESSIONAL SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, professionally dressed, stylised hair, clean face" # concept to erase +# unconditional: "male person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, professionally dressed, stylised hair, clean face" # concept to erase +# unconditional: "female person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### GLASSES SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, wearing glasses" # concept to erase +# unconditional: "male person" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, wearing glasses" # concept to erase +# unconditional: "female person" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### ASTRONAUGHT SLIDER +# - target: "astronaught" # what word for erasing the positive concept from +# positive: "astronaught, with orange colored 
spacesuit" # concept to erase +# unconditional: "astronaught" # word to take the difference from the positive concept +# neutral: "astronaught" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### SMILING SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, smiling" # concept to erase +# unconditional: "male person, frowning" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, smiling" # concept to erase +# unconditional: "female person, frowning" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### CAR COLOR SLIDER +# - target: "car" # what word for erasing the positive concept from +# positive: "car, white color" # concept to erase +# unconditional: "car, black color" # word to take the difference from the positive concept +# neutral: "car" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### DETAILS SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase +# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### CARTOON SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, cartoon style, pixar style, animated style" # concept to erase +# unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, cartoon style, pixar style, animated style" # concept to erase +# unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# 
batch_size: 1 +####################################################################################################### CLAY SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, clay style, made out of clay, clay sculpture" # concept to erase +# unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, clay style, made out of clay, clay sculpture" # concept to erase +# unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### SCULPTURE SLIDER +- target: "male person" # what word for erasing the positive concept from + positive: "male person, cement sculpture, cement greek statue style" # concept to erase + unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept + neutral: "male person" # starting point for conditioning the target + action: "enhance" # erase or enhance + guidance_scale: 4 + resolution: 512 + dynamic_resolution: false + batch_size: 1 +- target: "female person" # what word for erasing the positive concept from + positive: "female person, cement sculpture, cement greek statue style" # concept to erase + unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept + neutral: "female person" # starting point for conditioning the target + action: "enhance" # erase or enhance + guidance_scale: 4 + resolution: 512 + dynamic_resolution: false + batch_size: 1 +####################################################################################################### METAL SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase +# unconditional: "wooden style, made out of wood" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### FESTIVE SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "festive, colorful banners, confetti, indian festival decorations, chinese festival decorations, fireworks, parade, cherry, gala, happy, celebrations" # concept to erase +# unconditional: "dull, dark, sad, desserted, empty, alone" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### TROPICAL SLIDER +# - 
target: "" # what word for erasing the positive concept from +# positive: "tropical, beach, sunny, hot" # concept to erase +# unconditional: "arctic, winter, snow, ice, iceburg, snowfall" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### MODERN SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "modern, futuristic style, trendy, stylish, swank" # concept to erase +# unconditional: "ancient, classic style, regal, vintage" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### BOKEH SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "blurred background, narrow DOF, bokeh effect" # concept to erase +# # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept +# unconditional: "" +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### LONG HAIR SLIDER +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with long hair" # concept to erase +# unconditional: "male person, with short hair" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with long hair" # concept to erase +# unconditional: "female person, with short hair" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### NEGPROMPT SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "cartoon, cgi, render, illustration, painting, drawing, bad quality, grainy, low resolution" # concept to erase +# unconditional: "" +# neutral: "" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### EXPENSIVE FOOD SLIDER +# - target: "food" # what word for erasing the positive concept from +# positive: "food, expensive and fine dining" # concept to erase +# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept 
+# neutral: "food" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### COOKED FOOD SLIDER +# - target: "food" # what word for erasing the positive concept from +# positive: "food, cooked, baked, roasted, fried" # concept to erase +# unconditional: "food, raw, uncooked, fresh, undone" # word to take the difference from the positive concept +# neutral: "food" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### MEAT FOOD SLIDER +# - target: "food" # what word for erasing the positive concept from +# positive: "food, meat, steak, fish, non-vegetrian, beef, lamb, pork, chicken, salmon" # concept to erase +# unconditional: "food, vegetables, fruits, leafy-vegetables, greens, vegetarian, vegan, tomatoes, onions, carrots" # word to take the difference from the positive concept +# neutral: "food" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### WEATHER SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "snowy, winter, cold, ice, snowfall, white" # concept to erase +# unconditional: "hot, summer, bright, sunny" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### NIGHT/DAY SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "night time, dark, darkness, pitch black, nighttime" # concept to erase +# unconditional: "day time, bright, sunny, daytime, sunlight" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### INDOOR/OUTDOOR SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "indoor, inside a room, inside, interior" # concept to erase +# unconditional: "outdoor, outside, open air, exterior" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### GOODHANDS SLIDER +# - target: "" # what word for erasing the positive concept from +# positive: "realistic hands, realistic limbs, perfect limbs, perfect hands, 5 fingers, five fingers, hyper realisitc hands" # concept to erase +# unconditional: "poorly drawn limbs, distorted limbs, poorly rendered hands,bad anatomy, disfigured, 
mutated body parts, bad composition" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### RUSTY CAR SLIDER +# - target: "car" # what word for erasing the positive concept from +# positive: "car, rusty conditioned" # concept to erase +# unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept +# neutral: "car" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### RUSTY CAR SLIDER +# - target: "car" # what word for erasing the positive concept from +# positive: "car, damaged, broken headlights, dented car, with scrapped paintwork" # concept to erase +# unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept +# neutral: "car" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### CLUTTERED ROOM SLIDER +# - target: "room" # what word for erasing the positive concept from +# positive: "room, cluttered, disorganized, dirty, jumbled, scattered" # concept to erase +# unconditional: "room, super organized, clean, ordered, neat, tidy" # word to take the difference from the positive concept +# neutral: "room" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### HANDS SLIDER +# - target: "hands" # what word for erasing the positive concept from +# positive: "realistic hands, five fingers, 8k hyper realistic hands" # concept to erase +# unconditional: "poorly drawn hands, distorted hands, amputed fingers" # word to take the difference from the positive concept +# neutral: "hands" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +####################################################################################################### HANDS SLIDER +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with a surprised look" # concept to erase +# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 \ No newline at end of file diff --git a/trainscripts/textsliders/data/prompts.yaml b/trainscripts/textsliders/data/prompts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2078e6f2790769bd5957faa7340a64309519f979 --- /dev/null +++ b/trainscripts/textsliders/data/prompts.yaml @@ -0,0 +1,193 @@ +- target: "male 
person" # what word for erasing the positive concept from + positive: "male person, very old" # concept to erase + unconditional: "male person, very young" # word to take the difference from the positive concept + neutral: "male person" # starting point for conditioning the target + action: "enhance" # erase or enhance + guidance_scale: 4 + resolution: 512 + dynamic_resolution: false + batch_size: 1 +- target: "female person" # what word for erasing the positive concept from + positive: "female person, very old" # concept to erase + unconditional: "female person, very young" # word to take the difference from the positive concept + neutral: "female person" # starting point for conditioning the target + action: "enhance" # erase or enhance + guidance_scale: 4 + resolution: 512 + dynamic_resolution: false + batch_size: 1 +# - target: "" # what word for erasing the positive concept from +# positive: "a group of people" # concept to erase +# unconditional: "a person" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "" # what word for erasing the positive concept from +# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase +# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "" # what word for erasing the positive concept from +# positive: "blurred background, narrow DOF, bokeh effect" # concept to erase +# # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept +# unconditional: "" +# neutral: "" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "food" # what word for erasing the positive concept from +# positive: "food, expensive and fine dining" # concept to erase +# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept +# neutral: "food" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "room" # what word for erasing the positive concept from +# positive: "room, dirty disorganised and cluttered" # concept to erase +# unconditional: "room, neat organised and clean" # word to take the difference from the positive concept +# neutral: "room" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "male person" # what word for erasing the positive concept from +# positive: "male person, with a surprised look" # concept to erase +# unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "male person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 
+# dynamic_resolution: false +# batch_size: 1 +# - target: "female person" # what word for erasing the positive concept from +# positive: "female person, with a surprised look" # concept to erase +# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept +# neutral: "female person" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "sky" # what word for erasing the positive concept from +# positive: "peaceful sky" # concept to erase +# unconditional: "sky" # word to take the difference from the positive concept +# neutral: "sky" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "sky" # what word for erasing the positive concept from +# positive: "chaotic dark sky" # concept to erase +# unconditional: "sky" # word to take the difference from the positive concept +# neutral: "sky" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "person" # what word for erasing the positive concept from +# positive: "person, very young" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# overweight +# - target: "art" # what word for erasing the positive concept from +# positive: "realistic art" # concept to erase +# unconditional: "art" # word to take the difference from the positive concept +# neutral: "art" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "art" # what word for erasing the positive concept from +# positive: "abstract art" # concept to erase +# unconditional: "art" # word to take the difference from the positive concept +# neutral: "art" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# sky +# - target: "weather" # what word for erasing the positive concept from +# positive: "bright pleasant weather" # concept to erase +# unconditional: "weather" # word to take the difference from the positive concept +# neutral: "weather" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "weather" # what word for erasing the positive concept from +# positive: "dark gloomy weather" # concept to erase +# unconditional: "weather" # word to take the difference from the positive concept +# neutral: "weather" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# hair +# - target: "person" # what word for erasing the positive concept from +# positive: "person with long hair" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for 
conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "person" # what word for erasing the positive concept from +# positive: "person with short hair" # concept to erase +# unconditional: "person" # word to take the difference from the positive concept +# neutral: "person" # starting point for conditioning the target +# action: "erase" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "girl" # what word for erasing the positive concept from +# positive: "baby girl" # concept to erase +# unconditional: "girl" # word to take the difference from the positive concept +# neutral: "girl" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: -4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "boy" # what word for erasing the positive concept from +# positive: "old man" # concept to erase +# unconditional: "boy" # word to take the difference from the positive concept +# neutral: "boy" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: 4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 +# - target: "boy" # what word for erasing the positive concept from +# positive: "baby boy" # concept to erase +# unconditional: "boy" # word to take the difference from the positive concept +# neutral: "boy" # starting point for conditioning the target +# action: "enhance" # erase or enhance +# guidance_scale: -4 +# resolution: 512 +# dynamic_resolution: false +# batch_size: 1 \ No newline at end of file diff --git a/trainscripts/textsliders/debug_util.py b/trainscripts/textsliders/debug_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5b115093b0ad6625db2f9aedaaa90924cf7de468 --- /dev/null +++ b/trainscripts/textsliders/debug_util.py @@ -0,0 +1,16 @@ +# デバッグ用... 
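+# For debugging: helpers that print, for the first few submodules of a model, whether their
+# parameters require gradients and whether the modules are in training mode.
+# Illustrative usage (`network` / `unet` here stand for whatever modules are trained or frozen):
+#   from trainscripts.textsliders import debug_util
+#   debug_util.check_requires_grad(network)   # are the LoRA parameters trainable?
+#   debug_util.check_training_mode(unet)      # is the frozen UNet still in eval mode?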
+ +import torch + + +def check_requires_grad(model: torch.nn.Module): + for name, module in list(model.named_modules())[:5]: + if len(list(module.parameters())) > 0: + print(f"Module: {name}") + for name, param in list(module.named_parameters())[:2]: + print(f" Parameter: {name}, Requires Grad: {param.requires_grad}") + + +def check_training_mode(model: torch.nn.Module): + for name, module in list(model.named_modules())[:5]: + print(f"Module: {name}, Training Mode: {module.training}") diff --git a/trainscripts/textsliders/flush.py b/trainscripts/textsliders/flush.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc03f21aa0cf4050f94a096c253f17608600c5d --- /dev/null +++ b/trainscripts/textsliders/flush.py @@ -0,0 +1,5 @@ +import torch +import gc + +torch.cuda.empty_cache() +gc.collect() diff --git a/trainscripts/textsliders/generate_images_xl.py b/trainscripts/textsliders/generate_images_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..8931e07c05a1ae9896de8829eb9c9f8a777b9410 --- /dev/null +++ b/trainscripts/textsliders/generate_images_xl.py @@ -0,0 +1,513 @@ +import torch +from PIL import Image +import argparse +import os, json, random +import pandas as pd +import matplotlib.pyplot as plt +import glob, re + +from safetensors.torch import load_file +import matplotlib.image as mpimg +import copy +import gc +from transformers import CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import DiffusionPipeline +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor +from typing import Any, Dict, List, Optional, Tuple, Union +from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +import inspect +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from diffusers.pipelines import StableDiffusionXLPipeline +import random + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +import re +import argparse + +def flush(): + torch.cuda.empty_cache() + gc.collect() + +@torch.no_grad() +def call( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + 
target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + + network=None, + start_noise=None, + scale=None, + unet=None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
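+            The following arguments are specific to this Concept Sliders override and are not part
+            of the stock diffusers signature:
+            network (`LoRANetwork`, *optional*):
+                The slider LoRA network wrapped around `unet`. Its strength is switched per timestep
+                via `network.set_lora_slider(...)` and it is entered as a context manager around the
+                UNet forward pass.
+            start_noise (`int`, *optional*):
+                Timestep threshold for the slider: at timesteps greater than `start_noise` the slider
+                scale is set to 0, so the slider only affects the later denoising steps.
+            scale (`float`, *optional*):
+                Slider strength applied once the denoising timestep drops to `start_noise` or below.
+            unet (`UNet2DConditionModel`, *optional*):
+                The UNet used for the noise prediction (passed explicitly rather than read from the
+                pipeline).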
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if t>start_noise: + network.set_lora_slider(scale=0) + else: + network.set_lora_slider(scale=scale) + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + with network: + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models +# self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + +def sorted_nicely( l ): + convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text + alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ] + return sorted(l, key = alphanum_key) + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +if __name__=='__main__': + + device = 'cuda:0' + StableDiffusionXLPipeline.__call__ = call + pipe = 
StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0') + + # pipe.__call__ = call + pipe = pipe.to(device) + + + parser = argparse.ArgumentParser( + prog = 'generateImages', + description = 'Generate Images using Diffusers Code') + parser.add_argument('--model_name', help='name of model', type=str, required=True) + parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True) + parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None) + parser.add_argument('--save_path', help='folder where to save images', type=str, required=True) + parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4') + parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5) + parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512) + parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000) + parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0) + parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=5) + parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50) + parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4) + parser.add_argument('--start_noise', help='what time stamp to flip to edited model', type=int, required=False, default=750) + + args = parser.parse_args() + lora_weight = args.model_name + csv_path = args.prompts_path + save_path = args.save_path + start_noise = args.start_noise + from_case = args.from_case + till_case = args.till_case + + weight_dtype = torch.float16 + num_images_per_prompt = 1 + scales = [-2, -1, 0, 1, 2] + scales = [-1, -.5, 0, .5, 1] + scales = [-2] + df = pd.read_csv(csv_path) + + for scale in scales: + os.makedirs(f'{save_path}/{os.path.basename(lora_weight)}/{scale}', exist_ok=True) + + prompts = list(df['prompt']) + seeds = list(df['evaluation_seed']) + case_numbers = list(df['case_number']) + pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',torch_dtype=torch.float16,) + + # pipe.__call__ = call + pipe = pipe.to(device) + unet = pipe.unet + if 'full' in lora_weight: + train_method = 'full' + elif 'noxattn' in lora_weight: + train_method = 'noxattn' + else: + train_method = 'noxattn' + + network_type = "c3lier" + if train_method == 'xattn': + network_type = 'lierla' + + modules = DEFAULT_TARGET_REPLACE + if network_type == "c3lier": + modules += UNET_TARGET_REPLACE_MODULE_CONV + import os + model_name = lora_weight + + name = os.path.basename(model_name) + rank = 1 + alpha = 4 + if 'rank4' in lora_weight: + rank = 4 + if 'rank8' in lora_weight: + rank = 8 + if 'alpha1' in lora_weight: + alpha = 1.0 + network = LoRANetwork( + unet, + rank=rank, + multiplier=1.0, + alpha=alpha, + train_method=train_method, + ).to(device, dtype=weight_dtype) + network.load_state_dict(torch.load(lora_weight)) + + for idx, prompt in enumerate(prompts): + seed = seeds[idx] + case_number = case_numbers[idx] + + if not (case_number>=from_case and case_number<=till_case): + continue + if os.path.exists(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png'): + continue + print(prompt, 
seed) + for scale in scales: + generator = torch.manual_seed(seed) + images = pipe(prompt, num_images_per_prompt=args.num_samples, num_inference_steps=50, generator=generator, network=network, start_noise=start_noise, scale=scale, unet=unet).images + for idx, im in enumerate(images): + im.save(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png') + del unet, network, pipe + unet = None + network = None + pipe = None + torch.cuda.empty_cache() + flush() \ No newline at end of file diff --git a/trainscripts/textsliders/lora.py b/trainscripts/textsliders/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..98f6e55ac0c08c7d45fe30b94d260ec52133e215 --- /dev/null +++ b/trainscripts/textsliders/lora.py @@ -0,0 +1,258 @@ +# ref: +# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py + +import os +import math +from typing import Optional, List, Type, Set, Literal + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from safetensors.torch import save_file + + +UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ +# "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2 + "Attention" +] +UNET_TARGET_REPLACE_MODULE_CONV = [ + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + "DownBlock2D", + "UpBlock2D", + +] # locon, 3clier + +LORA_PREFIX_UNET = "lora_unet" + +DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER + +TRAINING_METHODS = Literal[ + "noxattn", # train all layers except x-attns and time_embed layers + "innoxattn", # train all layers except self attention layers + "selfattn", # ESD-u, train only self attention layers + "xattn", # ESD-x, train only x attention layers + "full", # train all layers + "xattn-strict", # q and k values + "noxattn-hspace", + "noxattn-hspace-last", + # "xlayer", + # "outxattn", + # "outsattn", + # "inxattn", + # "inmidsattn", + # "selflayer", +] + + +class LoRAModule(nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. 
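+    The patched forward returns `org_forward(x) + lora_up(lora_down(x)) * multiplier * scale`,
+    so the original weights stay untouched and setting `multiplier` to 0 disables the adapter.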
+ """ + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + if "Linear" in org_module.__class__.__name__: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + + elif "Conv" in org_module.__class__.__name__: # 一応 + in_dim = org_module.in_channels + out_dim = org_module.out_channels + + self.lora_dim = min(self.lora_dim, in_dim, out_dim) + if self.lora_dim != lora_dim: + print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = nn.Conv2d( + in_dim, self.lora_dim, kernel_size, stride, padding, bias=False + ) + self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + nn.init.kaiming_uniform_(self.lora_down.weight, a=1) + nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + return ( + self.org_forward(x) + + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + ) + + +class LoRANetwork(nn.Module): + def __init__( + self, + unet: UNet2DConditionModel, + rank: int = 4, + multiplier: float = 1.0, + alpha: float = 1.0, + train_method: TRAINING_METHODS = "full", + ) -> None: + super().__init__() + self.lora_scale = 1 + self.multiplier = multiplier + self.lora_dim = rank + self.alpha = alpha + + # LoRAのみ + self.module = LoRAModule + + # unetのloraを作る + self.unet_loras = self.create_modules( + LORA_PREFIX_UNET, + unet, + DEFAULT_TARGET_REPLACE, + self.lora_dim, + self.multiplier, + train_method=train_method, + ) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + # assertion 名前の被りがないか確認しているようだ + lora_names = set() + for lora in self.unet_loras: + assert ( + lora.lora_name not in lora_names + ), f"duplicated lora name: {lora.lora_name}. 
{lora_names}" + lora_names.add(lora.lora_name) + + # 適用する + for lora in self.unet_loras: + lora.apply_to() + self.add_module( + lora.lora_name, + lora, + ) + + del unet + + torch.cuda.empty_cache() + + def create_modules( + self, + prefix: str, + root_module: nn.Module, + target_replace_modules: List[str], + rank: int, + multiplier: float, + train_method: TRAINING_METHODS, + ) -> list: + loras = [] + names = [] + for name, module in root_module.named_modules(): + if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習 + if "attn2" in name or "time_embed" in name: + continue + elif train_method == "innoxattn": # Cross Attention 以外学習 + if "attn2" in name: + continue + elif train_method == "selfattn": # Self Attention のみ学習 + if "attn1" not in name: + continue + elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習 + if "attn2" not in name: + continue + elif train_method == "full": # 全部学習 + pass + else: + raise NotImplementedError( + f"train_method: {train_method} is not implemented." + ) + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]: + if train_method == 'xattn-strict': + if 'out' in child_name: + continue + if train_method == 'noxattn-hspace': + if 'mid_block' not in name: + continue + if train_method == 'noxattn-hspace-last': + if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name: + continue + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") +# print(f"{lora_name}") + lora = self.module( + lora_name, child_module, multiplier, rank, self.alpha + ) +# print(name, child_name) +# print(child_module.weight.shape) + if lora_name not in names: + loras.append(lora) + names.append(lora_name) +# print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}') + return loras + + def prepare_optimizer_params(self): + all_params = [] + + if self.unet_loras: # 実質これしかない + params = [] + [params.extend(lora.parameters()) for lora in self.unet_loras] + param_data = {"params": params} + all_params.append(param_data) + + return all_params + + def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + +# for key in list(state_dict.keys()): +# if not key.startswith("lora"): +# # lora以外除外 +# del state_dict[key] + + if os.path.splitext(file)[1] == ".safetensors": + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + def set_lora_slider(self, scale): + self.lora_scale = scale + + def __enter__(self): + for lora in self.unet_loras: + lora.multiplier = 1.0 * self.lora_scale + + def __exit__(self, exc_type, exc_value, tb): + for lora in self.unet_loras: + lora.multiplier = 0 diff --git a/trainscripts/textsliders/model_util.py b/trainscripts/textsliders/model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8a51f837e5b0c5e10d4ca4ca4244f16456518bd1 --- /dev/null +++ b/trainscripts/textsliders/model_util.py @@ -0,0 +1,278 @@ +from typing import Literal, Union, Optional + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection +from diffusers import ( + UNet2DConditionModel, + 
SchedulerMixin, + StableDiffusionPipeline, + StableDiffusionXLPipeline, +) +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + LMSDiscreteScheduler, + EulerAncestralDiscreteScheduler, +) + + +TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" +TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" + +AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"] + +SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] + +DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this + + +def load_diffusers_model( + pretrained_model_name_or_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + # VAE はいらない + + if v2: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V2_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + # default is clip skip 2 + num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + else: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V1_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + return tokenizer, text_encoder, unet + + +def load_checkpoint_model( + checkpoint_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + pipe = StableDiffusionPipeline.from_ckpt( + checkpoint_path, + upcast_attention=True if v2 else False, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + tokenizer = pipe.tokenizer + text_encoder = pipe.text_encoder + if clip_skip is not None: + if v2: + text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) + else: + text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) + + del pipe + + return tokenizer, text_encoder, unet + + +def load_models( + pretrained_model_name_or_path: str, + scheduler_name: AVAILABLE_SCHEDULERS, + v2: bool = False, + v_pred: bool = False, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + tokenizer, text_encoder, unet = load_checkpoint_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + else: # diffusers + tokenizer, text_encoder, unet = load_diffusers_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + + # VAE はいらない + + scheduler = create_noise_scheduler( + scheduler_name, + prediction_type="v_prediction" if v_pred else "epsilon", + ) + + return tokenizer, text_encoder, unet, scheduler + + +def load_diffusers_model_xl( + pretrained_model_name_or_path: str, + weight_dtype: torch.dtype = torch.float32, 
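+    # (editor note) usage sketch, assuming the SDXL base checkpoint referenced elsewhere in this repo:
+    #   tokenizers, text_encoders, unet = load_diffusers_model_xl("stabilityai/stable-diffusion-xl-base-1.0", torch.float16)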
+) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet + + tokenizers = [ + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + pad_token_id=0, # same as open clip + ), + ] + + text_encoders = [ + CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTextModelWithProjection.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + ] + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + return tokenizers, text_encoders, unet + + +def load_checkpoint_model_xl( + checkpoint_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + pipe = StableDiffusionXLPipeline.from_single_file( + checkpoint_path, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + if len(text_encoders) == 2: + text_encoders[1].pad_token_id = 0 + + del pipe + + return tokenizers, text_encoders, unet + + +def load_models_xl( + pretrained_model_name_or_path: str, + scheduler_name: AVAILABLE_SCHEDULERS, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[ + list[CLIPTokenizer], + list[SDXL_TEXT_ENCODER_TYPE], + UNet2DConditionModel, + SchedulerMixin, +]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + ( + tokenizers, + text_encoders, + unet, + ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype) + else: # diffusers + ( + tokenizers, + text_encoders, + unet, + ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype) + + scheduler = create_noise_scheduler(scheduler_name) + + return tokenizers, text_encoders, unet, scheduler + + +def create_noise_scheduler( + scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", +) -> SchedulerMixin: + # 正直、どれがいいのかわからない。元の実装だとDDIMとDDPMとLMSを選べたのだけど、どれがいいのかわからぬ。 + + name = scheduler_name.lower().replace(" ", "_") + if name == "ddim": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + prediction_type=prediction_type, # これでいいの? 
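+            # (editor note) beta_start/beta_end/"scaled_linear" above are the standard Stable Diffusion
+            # noise-schedule settings; prediction_type is "v_prediction" only when load_models is called
+            # with v_pred=True, otherwise "epsilon".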
+ ) + elif name == "ddpm": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm + scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + prediction_type=prediction_type, + ) + elif name == "lms": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete + scheduler = LMSDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + prediction_type=prediction_type, + ) + elif name == "euler_a": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral + scheduler = EulerAncestralDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + prediction_type=prediction_type, + ) + else: + raise ValueError(f"Unknown scheduler name: {name}") + + return scheduler diff --git a/trainscripts/textsliders/prompt_util.py b/trainscripts/textsliders/prompt_util.py new file mode 100644 index 0000000000000000000000000000000000000000..27de3eaa42044eba09b0917e694b441fb72a4ef7 --- /dev/null +++ b/trainscripts/textsliders/prompt_util.py @@ -0,0 +1,174 @@ +from typing import Literal, Optional, Union, List + +import yaml +from pathlib import Path + + +from pydantic import BaseModel, root_validator +import torch +import copy + +ACTION_TYPES = Literal[ + "erase", + "enhance", +] + + +# XL は二種類必要なので +class PromptEmbedsXL: + text_embeds: torch.FloatTensor + pooled_embeds: torch.FloatTensor + + def __init__(self, *args) -> None: + self.text_embeds = args[0] + self.pooled_embeds = args[1] + + +# SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL +PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL] + + +class PromptEmbedsCache: # 使いまわしたいので + prompts: dict[str, PROMPT_EMBEDDING] = {} + + def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None: + self.prompts[__name] = __value + + def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]: + if __name in self.prompts: + return self.prompts[__name] + else: + return None + + +class PromptSettings(BaseModel): # yaml のやつ + target: str + positive: str = None # if None, target will be used + unconditional: str = "" # default is "" + neutral: str = None # if None, unconditional will be used + action: ACTION_TYPES = "erase" # default is "erase" + guidance_scale: float = 1.0 # default is 1.0 + resolution: int = 512 # default is 512 + dynamic_resolution: bool = False # default is False + batch_size: int = 1 # default is 1 + dynamic_crops: bool = False # default is False. 
only used when model is XL + + @root_validator(pre=True) + def fill_prompts(cls, values): + keys = values.keys() + if "target" not in keys: + raise ValueError("target must be specified") + if "positive" not in keys: + values["positive"] = values["target"] + if "unconditional" not in keys: + values["unconditional"] = "" + if "neutral" not in keys: + values["neutral"] = values["unconditional"] + + return values + + +class PromptEmbedsPair: + target: PROMPT_EMBEDDING # not want to generate the concept + positive: PROMPT_EMBEDDING # generate the concept + unconditional: PROMPT_EMBEDDING # uncondition (default should be empty) + neutral: PROMPT_EMBEDDING # base condition (default should be empty) + + guidance_scale: float + resolution: int + dynamic_resolution: bool + batch_size: int + dynamic_crops: bool + + loss_fn: torch.nn.Module + action: ACTION_TYPES + + def __init__( + self, + loss_fn: torch.nn.Module, + target: PROMPT_EMBEDDING, + positive: PROMPT_EMBEDDING, + unconditional: PROMPT_EMBEDDING, + neutral: PROMPT_EMBEDDING, + settings: PromptSettings, + ) -> None: + self.loss_fn = loss_fn + self.target = target + self.positive = positive + self.unconditional = unconditional + self.neutral = neutral + + self.guidance_scale = settings.guidance_scale + self.resolution = settings.resolution + self.dynamic_resolution = settings.dynamic_resolution + self.batch_size = settings.batch_size + self.dynamic_crops = settings.dynamic_crops + self.action = settings.action + + def _erase( + self, + target_latents: torch.FloatTensor, # "van gogh" + positive_latents: torch.FloatTensor, # "van gogh" + unconditional_latents: torch.FloatTensor, # "" + neutral_latents: torch.FloatTensor, # "" + ) -> torch.FloatTensor: + """Target latents are going not to have the positive concept.""" + return self.loss_fn( + target_latents, + neutral_latents + - self.guidance_scale * (positive_latents - unconditional_latents) + ) + + + def _enhance( + self, + target_latents: torch.FloatTensor, # "van gogh" + positive_latents: torch.FloatTensor, # "van gogh" + unconditional_latents: torch.FloatTensor, # "" + neutral_latents: torch.FloatTensor, # "" + ): + """Target latents are going to have the positive concept.""" + return self.loss_fn( + target_latents, + neutral_latents + + self.guidance_scale * (positive_latents - unconditional_latents) + ) + + def loss( + self, + **kwargs, + ): + if self.action == "erase": + return self._erase(**kwargs) + + elif self.action == "enhance": + return self._enhance(**kwargs) + + else: + raise ValueError("action must be erase or enhance") + + +def load_prompts_from_yaml(path, attributes = []): + with open(path, "r") as f: + prompts = yaml.safe_load(f) + print(prompts) + if len(prompts) == 0: + raise ValueError("prompts file is empty") + if len(attributes)!=0: + newprompts = [] + for i in range(len(prompts)): + for att in attributes: + copy_ = copy.deepcopy(prompts[i]) + copy_['target'] = att + ' ' + copy_['target'] + copy_['positive'] = att + ' ' + copy_['positive'] + copy_['neutral'] = att + ' ' + copy_['neutral'] + copy_['unconditional'] = att + ' ' + copy_['unconditional'] + newprompts.append(copy_) + else: + newprompts = copy.deepcopy(prompts) + + print(newprompts) + print(len(prompts), len(newprompts)) + prompt_settings = [PromptSettings(**prompt) for prompt in newprompts] + + return prompt_settings diff --git a/trainscripts/textsliders/ptp_utils.py b/trainscripts/textsliders/ptp_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..c64b0867d9dabbcf8a440664749243bfed32c389 --- /dev/null +++ b/trainscripts/textsliders/ptp_utils.py @@ -0,0 +1,295 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +import cv2 +from typing import Optional, Union, Tuple, List, Callable, Dict +from IPython.display import display +from tqdm.notebook import tqdm + + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) + return img + + +def view_images(images, num_rows=1, offset_ratio=0.02): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + pil_img = Image.fromarray(image_) + display(pil_img) + + +def diffusion_step(unet, model, controller, latents, context, t, guidance_scale, low_resource=False): + if low_resource: + noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] + noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] + else: + latents_input = torch.cat([latents] * 2) + noise_pred = unet(latents_input, t, encoder_hidden_states=context)["sample"] + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] + latents = controller.step_callback(latents) + return latents + + +def latent2image(vae, latents): + latents = 1 / 0.18215 * latents + image = vae.decode(latents)['sample'] + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image * 255).astype(np.uint8) + return image + + +def init_latent(latent, model, height, width, generator, batch_size): + if latent is None: + latent = torch.randn( + (1, model.unet.in_channels, height // 8, 
width // 8), + generator=generator, + ) + latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) + return latent, latents + + +@torch.no_grad() +def text2image_ldm( + model, + prompt: List[str], + controller, + num_inference_steps: int = 50, + guidance_scale: Optional[float] = 7., + generator: Optional[torch.Generator] = None, + latent: Optional[torch.FloatTensor] = None, +): + register_attention_control(model, controller) + height = width = 256 + batch_size = len(prompt) + + uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") + uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] + + text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") + text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] + latent, latents = init_latent(latent, model, height, width, generator, batch_size) + context = torch.cat([uncond_embeddings, text_embeddings]) + + model.scheduler.set_timesteps(num_inference_steps) + for t in tqdm(model.scheduler.timesteps): + latents = diffusion_step(model, controller, latents, context, t, guidance_scale) + + image = latent2image(model.vqvae, latents) + + return image, latent + + +@torch.no_grad() +def text2image_ldm_stable( + model, + prompt: List[str], + controller, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: Optional[torch.Generator] = None, + latent: Optional[torch.FloatTensor] = None, + low_resource: bool = False, +): + register_attention_control(model, controller) + height = width = 512 + batch_size = len(prompt) + + text_input = model.tokenizer( + prompt, + padding="max_length", + max_length=model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] + max_length = text_input.input_ids.shape[-1] + uncond_input = model.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] + + context = [uncond_embeddings, text_embeddings] + if not low_resource: + context = torch.cat(context) + latent, latents = init_latent(latent, model, height, width, generator, batch_size) + + # set timesteps + extra_set_kwargs = {"offset": 1} + model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + for t in tqdm(model.scheduler.timesteps): + latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) + + image = latent2image(model.vae, latents) + + return image, latent + + +def register_attention_control(model, controller): + def ca_forward(self, place_in_unet): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + def forward(x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape + h = self.heads + q = self.to_q(x) + is_cross = context is not None + context = context if is_cross else x + k = self.to_k(context) + v = self.to_v(context) + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if mask is not None: + mask = mask.reshape(batch_size, -1) + max_neg_value = -torch.finfo(sim.dtype).max + mask = mask[:, None, :].repeat(h, 1, 1) + sim.masked_fill_(~mask, 
max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + attn = controller(attn, is_cross, place_in_unet) + out = torch.einsum("b i j, b j d -> b i d", attn, v) + out = self.reshape_batch_dim_to_heads(out) + return to_out(out) + + return forward + + class DummyController: + + def __call__(self, *args): + return args[0] + + def __init__(self): + self.num_att_layers = 0 + + if controller is None: + controller = DummyController() + + def register_recr(net_, count, place_in_unet): + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + elif hasattr(net_, 'children'): + for net__ in net_.children(): + count = register_recr(net__, count, place_in_unet) + return count + + cross_att_count = 0 + sub_nets = model.unet.named_children() + for net in sub_nets: + if "down" in net[0]: + cross_att_count += register_recr(net[1], 0, "down") + elif "up" in net[0]: + cross_att_count += register_recr(net[1], 0, "up") + elif "mid" in net[0]: + cross_att_count += register_recr(net[1], 0, "mid") + + controller.num_att_layers = cross_att_count + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, + word_inds: Optional[torch.Tensor]=None): + if type(bounds) is float: + bounds = 0, bounds + start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) + if word_inds is None: + word_inds = torch.arange(alpha.shape[2]) + alpha[: start, prompt_ind, word_inds] = 0 + alpha[start: end, prompt_ind, word_inds] = 1 + alpha[end:, prompt_ind, word_inds] = 0 + return alpha + + +def get_time_words_attention_alpha(prompts, num_steps, + cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], + tokenizer, max_num_words=77): + if type(cross_replace_steps) is not dict: + cross_replace_steps = {"default_": cross_replace_steps} + if "default_" not in cross_replace_steps: + cross_replace_steps["default_"] = (0., 1.) 
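+    # (editor note) alpha_time_words has shape (num_steps + 1, num_edited_prompts, max_num_words): for each
+    # denoising step and each prompt after the first, it holds a 0/1 per-token weight that is 1 only within the
+    # [start, end) fraction of steps given by cross_replace_steps, gating cross-attention control over that window.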
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) + for i in range(len(prompts) - 1): + alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], + i) + for key, item in cross_replace_steps.items(): + if key != "default_": + inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + for i, ind in enumerate(inds): + if len(ind) > 0: + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + return alpha_time_words \ No newline at end of file diff --git a/trainscripts/textsliders/train_lora.py b/trainscripts/textsliders/train_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1f79555f6d54d02e7b2d33cf0d538447fc33a9 --- /dev/null +++ b/trainscripts/textsliders/train_lora.py @@ -0,0 +1,419 @@ +# ref: +# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566 +# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py + +from typing import List, Optional +import argparse +import ast +from pathlib import Path +import gc + +import torch +from tqdm import tqdm + + +from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV +import train_util +import model_util +import prompt_util +from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings +import debug_util +import config_util +from config_util import RootConfig + +import wandb + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +def train( + config: RootConfig, + prompts: list[PromptSettings], + device: int +): + + metadata = { + "prompts": ",".join([prompt.json() for prompt in prompts]), + "config": config.json(), + } + save_path = Path(config.save.path) + + modules = DEFAULT_TARGET_REPLACE + if config.network.type == "c3lier": + modules += UNET_TARGET_REPLACE_MODULE_CONV + + if config.logging.verbose: + print(metadata) + + if config.logging.use_wandb: + wandb.init(project=f"LECO_{config.save.name}", config=metadata) + + weight_dtype = config_util.parse_precision(config.train.precision) + save_weight_dtype = config_util.parse_precision(config.train.precision) + + tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models( + config.pretrained_model.name_or_path, + scheduler_name=config.train.noise_scheduler, + v2=config.pretrained_model.v2, + v_pred=config.pretrained_model.v_pred, + ) + + text_encoder.to(device, dtype=weight_dtype) + text_encoder.eval() + + unet.to(device, dtype=weight_dtype) + unet.enable_xformers_memory_efficient_attention() + unet.requires_grad_(False) + unet.eval() + + network = LoRANetwork( + unet, + rank=config.network.rank, + multiplier=1.0, + alpha=config.network.alpha, + train_method=config.network.training_method, + ).to(device, dtype=weight_dtype) + + optimizer_module = train_util.get_optimizer(config.train.optimizer) + #optimizer_args + optimizer_kwargs = {} + if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0: + for arg in config.train.optimizer_args.split(" "): + key, value = arg.split("=") + value = ast.literal_eval(value) + optimizer_kwargs[key] = value + + optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs) + lr_scheduler = train_util.get_lr_scheduler( + config.train.lr_scheduler, + optimizer, + 
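+        # (editor note) with the "cosine" setting this anneals from config.train.lr down to lr_min = lr / 100
+        # over config.train.iterations steps (see train_util.get_lr_scheduler).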
max_iterations=config.train.iterations, + lr_min=config.train.lr / 100, + ) + criteria = torch.nn.MSELoss() + + print("Prompts") + for settings in prompts: + print(settings) + + # debug + debug_util.check_requires_grad(network) + debug_util.check_training_mode(network) + + cache = PromptEmbedsCache() + prompt_pairs: list[PromptEmbedsPair] = [] + + with torch.no_grad(): + for settings in prompts: + print(settings) + for prompt in [ + settings.target, + settings.positive, + settings.neutral, + settings.unconditional, + ]: + print(prompt) + if isinstance(prompt, list): + if prompt == settings.positive: + key_setting = 'positive' + else: + key_setting = 'attributes' + if len(prompt) == 0: + cache[key_setting] = [] + else: + if cache[key_setting] is None: + cache[key_setting] = train_util.encode_prompts( + tokenizer, text_encoder, prompt + ) + else: + if cache[prompt] == None: + cache[prompt] = train_util.encode_prompts( + tokenizer, text_encoder, [prompt] + ) + + prompt_pairs.append( + PromptEmbedsPair( + criteria, + cache[settings.target], + cache[settings.positive], + cache[settings.unconditional], + cache[settings.neutral], + settings, + ) + ) + + del tokenizer + del text_encoder + + flush() + + pbar = tqdm(range(config.train.iterations)) + + for i in pbar: + with torch.no_grad(): + noise_scheduler.set_timesteps( + config.train.max_denoising_steps, device=device + ) + + optimizer.zero_grad() + + prompt_pair: PromptEmbedsPair = prompt_pairs[ + torch.randint(0, len(prompt_pairs), (1,)).item() + ] + + # 1 ~ 49 からランダム + timesteps_to = torch.randint( + 1, config.train.max_denoising_steps, (1,) + ).item() + + height, width = ( + prompt_pair.resolution, + prompt_pair.resolution, + ) + if prompt_pair.dynamic_resolution: + height, width = train_util.get_random_resolution_in_bucket( + prompt_pair.resolution + ) + + if config.logging.verbose: + print("guidance_scale:", prompt_pair.guidance_scale) + print("resolution:", prompt_pair.resolution) + print("dynamic_resolution:", prompt_pair.dynamic_resolution) + if prompt_pair.dynamic_resolution: + print("bucketed resolution:", (height, width)) + print("batch_size:", prompt_pair.batch_size) + + latents = train_util.get_initial_latents( + noise_scheduler, prompt_pair.batch_size, height, width, 1 + ).to(device, dtype=weight_dtype) + + with network: + # ちょっとデノイズされれたものが返る + denoised_latents = train_util.diffusion( + unet, + noise_scheduler, + latents, # 単純なノイズのlatentsを渡す + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.target, + prompt_pair.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + noise_scheduler.set_timesteps(1000) + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / config.train.max_denoising_steps) + ] + + # with network: の外では空のLoRAのみが有効になる + positive_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.positive, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + + neutral_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.neutral, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + unconditional_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + 
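+                # (editor note) the unconditional embedding is paired with itself here, so this call returns the
+                # model's fully unconditional noise prediction (guidance_scale=1 keeps it unguided).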
prompt_pair.unconditional, + prompt_pair.unconditional, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + + + ######################### + if config.logging.verbose: + print("positive_latents:", positive_latents[0, 0, :5, :5]) + print("neutral_latents:", neutral_latents[0, 0, :5, :5]) + print("unconditional_latents:", unconditional_latents[0, 0, :5, :5]) + + with network: + target_latents = train_util.predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + train_util.concat_embeddings( + prompt_pair.unconditional, + prompt_pair.target, + prompt_pair.batch_size, + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + + ######################### + + if config.logging.verbose: + print("target_latents:", target_latents[0, 0, :5, :5]) + + positive_latents.requires_grad = False + neutral_latents.requires_grad = False + unconditional_latents.requires_grad = False + + loss = prompt_pair.loss( + target_latents=target_latents, + positive_latents=positive_latents, + neutral_latents=neutral_latents, + unconditional_latents=unconditional_latents, + ) + + # 1000倍しないとずっと0.000...になってしまって見た目的に面白くない + pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}") + if config.logging.use_wandb: + wandb.log( + {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]} + ) + + loss.backward() + optimizer.step() + lr_scheduler.step() + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + target_latents, + latents, + ) + flush() + + if ( + i % config.save.per_steps == 0 + and i != 0 + and i != config.train.iterations - 1 + ): + print("Saving...") + save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_{i}steps.pt", + dtype=save_weight_dtype, + ) + + print("Saving...") + save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_last.pt", + dtype=save_weight_dtype, + ) + + del ( + unet, + noise_scheduler, + loss, + optimizer, + network, + ) + + flush() + + print("Done.") + + +def main(args): + config_file = args.config_file + + config = config_util.load_config_from_yaml(config_file) + if args.name is not None: + config.save.name = args.name + attributes = [] + if args.attributes is not None: + attributes = args.attributes.split(',') + attributes = [a.strip() for a in attributes] + + config.network.alpha = args.alpha + config.network.rank = args.rank + config.save.name += f'_alpha{args.alpha}' + config.save.name += f'_rank{config.network.rank }' + config.save.name += f'_{config.network.training_method}' + config.save.path += f'/{config.save.name}' + + prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes) + device = torch.device(f"cuda:{args.device}") + + train(config=config, prompts=prompts, device=device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_file", + required=False, + default = 'data/config.yaml', + help="Config file for training.", + ) + # config_file 'data/config.yaml' + parser.add_argument( + "--alpha", + type=float, + required=True, + help="LoRA weight.", + ) + # --alpha 1.0 + parser.add_argument( + "--rank", + type=int, + required=False, + help="Rank of LoRA.", + default=4, + ) + # --rank 4 + parser.add_argument( + "--device", + type=int, + required=False, + default=0, + help="Device to train on.", + ) + # --device 0 + parser.add_argument( + "--name", + type=str, + required=False, + default=None, + help="Device to train on.", + ) + 
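+    # example invocation (sketch; flag values are illustrative):
+    #   python trainscripts/textsliders/train_lora.py --config_file 'data/config.yaml' --alpha 1.0 --rank 4 --name 'eyesize_slider' --attributes 'male, female' --device 0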
# --name 'eyesize_slider' + parser.add_argument( + "--attributes", + type=str, + required=False, + default=None, + help="attritbutes to disentangle (comma seperated string)", + ) + + # --attributes 'male, female' + + args = parser.parse_args() + + main(args) diff --git a/trainscripts/textsliders/train_lora_xl.py b/trainscripts/textsliders/train_lora_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2e488da033eb181d52fe58a0d6fffdee406b56 --- /dev/null +++ b/trainscripts/textsliders/train_lora_xl.py @@ -0,0 +1,463 @@ +# ref: +# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566 +# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py + +from typing import List, Optional +import argparse +import ast +from pathlib import Path +import gc + +import torch +from tqdm import tqdm + + +from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV +import train_util +import model_util +import prompt_util +from prompt_util import ( + PromptEmbedsCache, + PromptEmbedsPair, + PromptSettings, + PromptEmbedsXL, +) +import debug_util +import config_util +from config_util import RootConfig + +import wandb + +NUM_IMAGES_PER_PROMPT = 1 + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +def train( + config: RootConfig, + prompts: list[PromptSettings], + device, +): + metadata = { + "prompts": ",".join([prompt.json() for prompt in prompts]), + "config": config.json(), + } + save_path = Path(config.save.path) + + modules = DEFAULT_TARGET_REPLACE + if config.network.type == "c3lier": + modules += UNET_TARGET_REPLACE_MODULE_CONV + + if config.logging.verbose: + print(metadata) + + if config.logging.use_wandb: + wandb.init(project=f"LECO_{config.save.name}", config=metadata) + + weight_dtype = config_util.parse_precision(config.train.precision) + save_weight_dtype = config_util.parse_precision(config.train.precision) + + ( + tokenizers, + text_encoders, + unet, + noise_scheduler, + ) = model_util.load_models_xl( + config.pretrained_model.name_or_path, + scheduler_name=config.train.noise_scheduler, + ) + + for text_encoder in text_encoders: + text_encoder.to(device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + unet.to(device, dtype=weight_dtype) + if config.other.use_xformers: + unet.enable_xformers_memory_efficient_attention() + unet.requires_grad_(False) + unet.eval() + + network = LoRANetwork( + unet, + rank=config.network.rank, + multiplier=1.0, + alpha=config.network.alpha, + train_method=config.network.training_method, + ).to(device, dtype=weight_dtype) + + optimizer_module = train_util.get_optimizer(config.train.optimizer) + #optimizer_args + optimizer_kwargs = {} + if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0: + for arg in config.train.optimizer_args.split(" "): + key, value = arg.split("=") + value = ast.literal_eval(value) + optimizer_kwargs[key] = value + + optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs) + lr_scheduler = train_util.get_lr_scheduler( + config.train.lr_scheduler, + optimizer, + max_iterations=config.train.iterations, + lr_min=config.train.lr / 100, + ) + criteria = torch.nn.MSELoss() + + print("Prompts") + for settings in prompts: + print(settings) + + # debug + debug_util.check_requires_grad(network) + debug_util.check_training_mode(network) + + cache = PromptEmbedsCache() + 
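+    # (editor note) sketch of the objective optimized in the loop below (see PromptEmbedsPair.loss in prompt_util):
+    #   enhance: target_latents -> neutral_latents + guidance_scale * (positive_latents - unconditional_latents)
+    #   erase:   target_latents -> neutral_latents - guidance_scale * (positive_latents - unconditional_latents)
+    # i.e. the LoRA is trained to move the target prompt's noise prediction along the positive-minus-unconditional direction.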
prompt_pairs: list[PromptEmbedsPair] = [] + + with torch.no_grad(): + for settings in prompts: + print(settings) + for prompt in [ + settings.target, + settings.positive, + settings.neutral, + settings.unconditional, + ]: + if cache[prompt] == None: + tex_embs, pool_embs = train_util.encode_prompts_xl( + tokenizers, + text_encoders, + [prompt], + num_images_per_prompt=NUM_IMAGES_PER_PROMPT, + ) + cache[prompt] = PromptEmbedsXL( + tex_embs, + pool_embs + ) + + prompt_pairs.append( + PromptEmbedsPair( + criteria, + cache[settings.target], + cache[settings.positive], + cache[settings.unconditional], + cache[settings.neutral], + settings, + ) + ) + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + del tokenizer, text_encoder + + flush() + + pbar = tqdm(range(config.train.iterations)) + + loss = None + + for i in pbar: + with torch.no_grad(): + noise_scheduler.set_timesteps( + config.train.max_denoising_steps, device=device + ) + + optimizer.zero_grad() + + prompt_pair: PromptEmbedsPair = prompt_pairs[ + torch.randint(0, len(prompt_pairs), (1,)).item() + ] + + # 1 ~ 49 からランダム + timesteps_to = torch.randint( + 1, config.train.max_denoising_steps, (1,) + ).item() + + height, width = prompt_pair.resolution, prompt_pair.resolution + if prompt_pair.dynamic_resolution: + height, width = train_util.get_random_resolution_in_bucket( + prompt_pair.resolution + ) + + if config.logging.verbose: + print("gudance_scale:", prompt_pair.guidance_scale) + print("resolution:", prompt_pair.resolution) + print("dynamic_resolution:", prompt_pair.dynamic_resolution) + if prompt_pair.dynamic_resolution: + print("bucketed resolution:", (height, width)) + print("batch_size:", prompt_pair.batch_size) + print("dynamic_crops:", prompt_pair.dynamic_crops) + + latents = train_util.get_initial_latents( + noise_scheduler, prompt_pair.batch_size, height, width, 1 + ).to(device, dtype=weight_dtype) + + add_time_ids = train_util.get_add_time_ids( + height, + width, + dynamic_crops=prompt_pair.dynamic_crops, + dtype=weight_dtype, + ).to(device, dtype=weight_dtype) + + with network: + # ちょっとデノイズされれたものが返る + denoised_latents = train_util.diffusion_xl( + unet, + noise_scheduler, + latents, # 単純なノイズのlatentsを渡す + text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.text_embeds, + prompt_pair.target.text_embeds, + prompt_pair.batch_size, + ), + add_text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.pooled_embeds, + prompt_pair.target.pooled_embeds, + prompt_pair.batch_size, + ), + add_time_ids=train_util.concat_embeddings( + add_time_ids, add_time_ids, prompt_pair.batch_size + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + noise_scheduler.set_timesteps(1000) + + current_timestep = noise_scheduler.timesteps[ + int(timesteps_to * 1000 / config.train.max_denoising_steps) + ] + + # with network: の外では空のLoRAのみが有効になる + positive_latents = train_util.predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.text_embeds, + prompt_pair.positive.text_embeds, + prompt_pair.batch_size, + ), + add_text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.pooled_embeds, + prompt_pair.positive.pooled_embeds, + prompt_pair.batch_size, + ), + add_time_ids=train_util.concat_embeddings( + add_time_ids, add_time_ids, prompt_pair.batch_size + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + neutral_latents = train_util.predict_noise_xl( 
+ unet, + noise_scheduler, + current_timestep, + denoised_latents, + text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.text_embeds, + prompt_pair.neutral.text_embeds, + prompt_pair.batch_size, + ), + add_text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.pooled_embeds, + prompt_pair.neutral.pooled_embeds, + prompt_pair.batch_size, + ), + add_time_ids=train_util.concat_embeddings( + add_time_ids, add_time_ids, prompt_pair.batch_size + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + unconditional_latents = train_util.predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.text_embeds, + prompt_pair.unconditional.text_embeds, + prompt_pair.batch_size, + ), + add_text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.pooled_embeds, + prompt_pair.unconditional.pooled_embeds, + prompt_pair.batch_size, + ), + add_time_ids=train_util.concat_embeddings( + add_time_ids, add_time_ids, prompt_pair.batch_size + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + + if config.logging.verbose: + print("positive_latents:", positive_latents[0, 0, :5, :5]) + print("neutral_latents:", neutral_latents[0, 0, :5, :5]) + print("unconditional_latents:", unconditional_latents[0, 0, :5, :5]) + + with network: + target_latents = train_util.predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.text_embeds, + prompt_pair.target.text_embeds, + prompt_pair.batch_size, + ), + add_text_embeddings=train_util.concat_embeddings( + prompt_pair.unconditional.pooled_embeds, + prompt_pair.target.pooled_embeds, + prompt_pair.batch_size, + ), + add_time_ids=train_util.concat_embeddings( + add_time_ids, add_time_ids, prompt_pair.batch_size + ), + guidance_scale=1, + ).to(device, dtype=weight_dtype) + + if config.logging.verbose: + print("target_latents:", target_latents[0, 0, :5, :5]) + + positive_latents.requires_grad = False + neutral_latents.requires_grad = False + unconditional_latents.requires_grad = False + + loss = prompt_pair.loss( + target_latents=target_latents, + positive_latents=positive_latents, + neutral_latents=neutral_latents, + unconditional_latents=unconditional_latents, + ) + + # 1000倍しないとずっと0.000...になってしまって見た目的に面白くない + pbar.set_description(f"Loss*1k: {loss.item()*1000:.4f}") + if config.logging.use_wandb: + wandb.log( + {"loss": loss, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]} + ) + + loss.backward() + optimizer.step() + lr_scheduler.step() + + del ( + positive_latents, + neutral_latents, + unconditional_latents, + target_latents, + latents, + ) + flush() + + if ( + i % config.save.per_steps == 0 + and i != 0 + and i != config.train.iterations - 1 + ): + print("Saving...") + save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_{i}steps.pt", + dtype=save_weight_dtype, + ) + + print("Saving...") + save_path.mkdir(parents=True, exist_ok=True) + network.save_weights( + save_path / f"{config.save.name}_last.pt", + dtype=save_weight_dtype, + ) + + del ( + unet, + noise_scheduler, + loss, + optimizer, + network, + ) + + flush() + + print("Done.") + + +def main(args): + config_file = args.config_file + + config = config_util.load_config_from_yaml(config_file) + if args.name is not None: + config.save.name = args.name + attributes = [] + if args.attributes is not None: 
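+        # (editor note) e.g. --attributes 'male, female' makes load_prompts_from_yaml prepend each attribute to the
+        # target/positive/neutral/unconditional prompts, so the slider is trained over those variants and stays
+        # disentangled from them.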
+ attributes = args.attributes.split(',') + attributes = [a.strip() for a in attributes] + + config.network.alpha = args.alpha + config.network.rank = args.rank + config.save.name += f'_alpha{args.alpha}' + config.save.name += f'_rank{config.network.rank }' + config.save.name += f'_{config.network.training_method}' + config.save.path += f'/{config.save.name}' + + prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes) + + device = torch.device(f"cuda:{args.device}") + train(config, prompts, device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config_file", + required=True, + help="Config file for training.", + ) + # config_file 'data/config.yaml' + parser.add_argument( + "--alpha", + type=float, + required=True, + help="LoRA weight.", + ) + # --alpha 1.0 + parser.add_argument( + "--rank", + type=int, + required=False, + help="Rank of LoRA.", + default=4, + ) + # --rank 4 + parser.add_argument( + "--device", + type=int, + required=False, + default=0, + help="Device to train on.", + ) + # --device 0 + parser.add_argument( + "--name", + type=str, + required=False, + default=None, + help="Device to train on.", + ) + # --name 'eyesize_slider' + parser.add_argument( + "--attributes", + type=str, + required=False, + default=None, + help="attritbutes to disentangle (comma seperated string)", + ) + + args = parser.parse_args() + + main(args) diff --git a/trainscripts/textsliders/train_util.py b/trainscripts/textsliders/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec76071281ea99459781f4a66ed2d7555f33a5f --- /dev/null +++ b/trainscripts/textsliders/train_util.py @@ -0,0 +1,419 @@ +from typing import Optional, Union + +import torch + +from transformers import CLIPTextModel, CLIPTokenizer +from diffusers import UNet2DConditionModel, SchedulerMixin + +from model_util import SDXL_TEXT_ENCODER_TYPE + +from tqdm import tqdm + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + +UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL +TEXT_ENCODER_2_PROJECTION_DIM = 1280 +UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816 + + +def get_random_noise( + batch_size: int, height: int, width: int, generator: torch.Generator = None +) -> torch.Tensor: + return torch.randn( + ( + batch_size, + UNET_IN_CHANNELS, + height // VAE_SCALE_FACTOR, # 縦と横これであってるのかわからないけど、どっちにしろ大きな問題は発生しないのでこれでいいや + width // VAE_SCALE_FACTOR, + ), + generator=generator, + device="cpu", + ) + + +# https://www.crosslabs.org/blog/diffusion-with-offset-noise +def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float): + latents = latents + noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + return latents + + +def get_initial_latents( + scheduler: SchedulerMixin, + n_imgs: int, + height: int, + width: int, + n_prompts: int, + generator=None, +) -> torch.Tensor: + noise = get_random_noise(n_imgs, height, width, generator=generator).repeat( + n_prompts, 1, 1, 1 + ) + + latents = noise * scheduler.init_noise_sigma + + return latents + + +def text_tokenize( + tokenizer: CLIPTokenizer, # 普通ならひとつ、XLならふたつ! 
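+    # (one tokenizer for SD v1/v2; SDXL passes two, see encode_prompts_xl below)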
+ prompts: list[str], +): + return tokenizer( + prompts, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids + + +def text_encode(text_encoder: CLIPTextModel, tokens): + return text_encoder(tokens.to(text_encoder.device))[0] + + +def encode_prompts( + tokenizer: CLIPTokenizer, + text_encoder: CLIPTokenizer, + prompts: list[str], +): + + text_tokens = text_tokenize(tokenizer, prompts) + text_embeddings = text_encode(text_encoder, text_tokens) + + + + return text_embeddings + + +# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348 +def text_encode_xl( + text_encoder: SDXL_TEXT_ENCODER_TYPE, + tokens: torch.FloatTensor, + num_images_per_prompt: int = 1, +): + prompt_embeds = text_encoder( + tokens.to(text_encoder.device), output_hidden_states=True + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, pooled_prompt_embeds + + +def encode_prompts_xl( + tokenizers: list[CLIPTokenizer], + text_encoders: list[SDXL_TEXT_ENCODER_TYPE], + prompts: list[str], + num_images_per_prompt: int = 1, +) -> tuple[torch.FloatTensor, torch.FloatTensor]: + # text_encoder and text_encoder_2's penuultimate layer's output + text_embeds_list = [] + pooled_text_embeds = None # always text_encoder_2's pool + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_tokens_input_ids = text_tokenize(tokenizer, prompts) + text_embeds, pooled_text_embeds = text_encode_xl( + text_encoder, text_tokens_input_ids, num_images_per_prompt + ) + + text_embeds_list.append(text_embeds) + + bs_embed = pooled_text_embeds.shape[0] + pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds + + +def concat_embeddings( + unconditional: torch.FloatTensor, + conditional: torch.FloatTensor, + n_imgs: int, +): + return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0) + + +# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721 +def predict_noise( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + timestep: int, # 現在のタイムステップ + latents: torch.FloatTensor, + text_embeddings: torch.FloatTensor, # uncond な text embed と cond な text embed を結合したもの + guidance_scale=7.5, +) -> torch.FloatTensor: + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
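+    # (editor note) classifier-free guidance below: noise = eps_uncond + guidance_scale * (eps_text - eps_uncond);
+    # text_embeddings is expected to already be cat([unconditional, conditional]) as built by concat_embeddings.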
+ latent_model_input = torch.cat([latents] * 2) + + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guided_target = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + return guided_target + + +# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 +@torch.no_grad() +def diffusion( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + latents: torch.FloatTensor, # ただのノイズだけのlatents + text_embeddings: torch.FloatTensor, + total_timesteps: int = 1000, + start_timesteps=0, + **kwargs, +): + # latents_steps = [] + + for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]): + noise_pred = predict_noise( + unet, scheduler, timestep, latents, text_embeddings, **kwargs + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + + # return latents_steps + return latents + + +def rescale_noise_cfg( + noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0 +): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + + return noise_cfg + + +def predict_noise_xl( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + timestep: int, # 現在のタイムステップ + latents: torch.FloatTensor, + text_embeddings: torch.FloatTensor, # uncond な text embed と cond な text embed を結合したもの + add_text_embeddings: torch.FloatTensor, # pooled なやつ + add_time_ids: torch.FloatTensor, + guidance_scale=7.5, + guidance_rescale=0.7, +) -> torch.FloatTensor: + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
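+    # (editor note) same classifier-free guidance as predict_noise, plus SDXL's added_cond_kwargs (pooled text
+    # embeds and time_ids). Note that rescale_noise_cfg below assigns to noise_pred while guided_target is what is
+    # returned, so guidance_rescale currently has no effect on the output.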
+ latent_model_input = torch.cat([latents] * 2) + + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + added_cond_kwargs = { + "text_embeds": add_text_embeddings, + "time_ids": add_time_ids, + } + + # predict the noise residual + noise_pred = unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guided_target = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) + + return guided_target + + +@torch.no_grad() +def diffusion_xl( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + latents: torch.FloatTensor, # ただのノイズだけのlatents + text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor], + add_text_embeddings: torch.FloatTensor, # pooled なやつ + add_time_ids: torch.FloatTensor, + guidance_scale: float = 1.0, + total_timesteps: int = 1000, + start_timesteps=0, +): + # latents_steps = [] + + for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]): + noise_pred = predict_noise_xl( + unet, + scheduler, + timestep, + latents, + text_embeddings, + add_text_embeddings, + add_time_ids, + guidance_scale=guidance_scale, + guidance_rescale=0.7, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + + # return latents_steps + return latents + + +# for XL +def get_add_time_ids( + height: int, + width: int, + dynamic_crops: bool = False, + dtype: torch.dtype = torch.float32, +): + if dynamic_crops: + # random float scale between 1 and 3 + random_scale = torch.rand(1).item() * 2 + 1 + original_size = (int(height * random_scale), int(width * random_scale)) + # random position + crops_coords_top_left = ( + torch.randint(0, original_size[0] - height, (1,)).item(), + torch.randint(0, original_size[1] - width, (1,)).item(), + ) + target_size = (height, width) + else: + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + + # this is expected as 6 + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + # this is expected as 2816 + passed_add_embed_dim = ( + UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6 + + TEXT_ENCODER_2_PROJECTION_DIM # + 1280 + ) + if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM: + raise ValueError( + f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+        )
+
+    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+    return add_time_ids
+
+
+def get_optimizer(name: str):
+    name = name.lower()
+
+    if name.startswith("dadapt"):
+        import dadaptation
+
+        if name == "dadaptadam":
+            return dadaptation.DAdaptAdam
+        elif name == "dadaptlion":
+            return dadaptation.DAdaptLion
+        else:
+            raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
+
+    elif name.endswith("8bit"):  # not tested
+        import bitsandbytes as bnb
+
+        if name == "adam8bit":
+            return bnb.optim.Adam8bit
+        elif name == "lion8bit":
+            return bnb.optim.Lion8bit
+        else:
+            raise ValueError("8bit optimizer must be adam8bit or lion8bit")
+
+    else:
+        if name == "adam":
+            return torch.optim.Adam
+        elif name == "adamw":
+            return torch.optim.AdamW
+        elif name == "lion":
+            from lion_pytorch import Lion
+
+            return Lion
+        elif name == "prodigy":
+            import prodigyopt
+
+            return prodigyopt.Prodigy
+        else:
+            raise ValueError("Optimizer must be adam, adamw, lion or prodigy")
+
+
+def get_lr_scheduler(
+    name: Optional[str],
+    optimizer: torch.optim.Optimizer,
+    max_iterations: Optional[int],
+    lr_min: Optional[float],
+    **kwargs,
+):
+    if name == "cosine":
+        return torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
+        )
+    elif name == "cosine_with_restarts":
+        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
+        )
+    elif name == "step":
+        return torch.optim.lr_scheduler.StepLR(
+            optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
+        )
+    elif name == "constant":
+        return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
+    elif name == "linear":
+        return torch.optim.lr_scheduler.LinearLR(
+            optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
+        )
+    else:
+        raise ValueError(
+            "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
+        )
+
+
+def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
+    max_resolution = bucket_resolution
+    min_resolution = bucket_resolution // 2
+
+    step = 64
+
+    min_step = min_resolution // step
+    max_step = max_resolution // step
+
+    height = torch.randint(min_step, max_step, (1,)).item() * step
+    width = torch.randint(min_step, max_step, (1,)).item() * step
+
+    return height, width
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0fadcc6e527ef6a24af0668e72053cbafe37447
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,391 @@
+import torch
+from PIL import Image
+import argparse
+import os, json, random
+import pandas as pd
+import matplotlib.pyplot as plt
+import glob, re
+
+from safetensors.torch import load_file
+import matplotlib.image as mpimg
+import copy
+import gc
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import DiffusionPipeline
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
+from typing import Any, Dict, List, Optional, Tuple, Union
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+
+import inspect
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from diffusers.pipelines import StableDiffusionXLPipeline
+import random
+
+import torch
+from transformers import
CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +def flush(): + torch.cuda.empty_cache() + gc.collect() + +@torch.no_grad() +def call( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + + network=None, + start_noise=None, + scale=None, + unet=None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + latents = latents.to(unet.dtype) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if t>start_noise: + network.set_lora_slider(scale=0) + else: + network.set_lora_slider(scale=scale) + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + with network: + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models +# self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image)
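The patched `call` above keeps the stock SDXL sampling loop but threads a slider LoRA `network` through it: the LoRA is zeroed for timesteps above `start_noise`, applied at strength `scale` afterwards, and the UNet is taken from the explicit `unet` argument rather than `self.unet`. Below is a minimal usage sketch, not part of the diff; `NoOpNetwork` is a hypothetical stand-in for a trained slider network, which only needs to expose `set_lora_slider(scale=...)` and behave as a context manager.

import torch
from diffusers import StableDiffusionXLPipeline

from utils import call


class NoOpNetwork:
    """Hypothetical stand-in for a trained slider LoRA network."""

    def set_lora_slider(self, scale):
        self.scale = scale  # a real network would rescale its LoRA weights here

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = call(
    pipe,
    prompt="a portrait photo of a person",
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator("cuda").manual_seed(0),
    network=NoOpNetwork(),  # swap in the trained slider network
    unet=pipe.unet,         # `call` reads the UNet from this argument
    start_noise=800,        # the slider is active only once t <= start_noise
    scale=2.0,              # slider strength forwarded to set_lora_slider
).images[0]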
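The XL helpers (`get_add_time_ids`, `predict_noise_xl`, `diffusion_xl`) can also be driven directly. A sketch under stated assumptions: the helpers and the `UNET_*`/`TEXT_ENCODER_2_*` constants they reference are importable from a module named `train_util` (hypothetical path), an SDXL checkpoint is available, and `encode_prompt` behaves as it does in the pipeline code above.

import torch
from diffusers import StableDiffusionXLPipeline

import train_util  # hypothetical module name for the helpers defined above

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
unet, scheduler = pipe.unet, pipe.scheduler
scheduler.set_timesteps(30, device="cuda")

# mirrors the encode_prompt call inside `call`: returns
# (cond, uncond, pooled cond, pooled uncond) embeddings
prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    prompt="a portrait photo of a person",
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
)
text_embeddings = torch.cat([neg_embeds, prompt_embeds])  # uncond first, cond second
add_text_embeds = torch.cat([neg_pooled, pooled])

height = width = 1024
add_time_ids = train_util.get_add_time_ids(height, width, dtype=torch.float16).to("cuda")
add_time_ids = torch.cat([add_time_ids] * 2)  # match the CFG batch of 2

latents = torch.randn(
    1, unet.config.in_channels, height // pipe.vae_scale_factor, width // pipe.vae_scale_factor,
    dtype=torch.float16, device="cuda",
) * scheduler.init_noise_sigma

denoised_latents = train_util.diffusion_xl(
    unet, scheduler, latents, text_embeddings, add_text_embeds, add_time_ids,
    guidance_scale=7.5, total_timesteps=30,
)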
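The optimizer and LR-scheduler factories are plain registries keyed by lowercase names. A short wiring sketch follows, again assuming the hypothetical `train_util` import path, with a placeholder parameter list standing in for the slider LoRA weights.

import torch

import train_util  # hypothetical module name for the helpers defined above

lora_params = [torch.nn.Parameter(torch.zeros(8, 8))]  # placeholder for LoRA weights

optimizer_cls = train_util.get_optimizer("adamw")  # resolves to torch.optim.AdamW
optimizer = optimizer_cls(lora_params, lr=1e-4)

lr_scheduler = train_util.get_lr_scheduler(
    "cosine", optimizer, max_iterations=1000, lr_min=1e-6
)

# random training resolution: multiples of 64 inside the bucket
height, width = train_util.get_random_resolution_in_bucket(512)

for step in range(1000):
    # ... compute the slider loss and call loss.backward() here ...
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()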