Spaces:
Runtime error
Runtime error
RohitGandikota
commited on
Commit
โข
1f8beea
1
Parent(s):
cb9665a
testing layout
Browse filesThis view is limited to 50 files because it contains too many changes. ย
See raw diff
- .gitignore +1 -0
- __init__.py +1 -0
- app.py +261 -4
- models/age.pt +3 -0
- models/cartoon_style.pt +3 -0
- models/chubby.pt +3 -0
- models/clay_style.pt +3 -0
- models/cluttered_room.pt +3 -0
- models/curlyhair.pt +3 -0
- models/dark_weather.pt +3 -0
- models/eyebrow.pt +3 -0
- models/eyesize.pt +3 -0
- models/festive.pt +3 -0
- models/fix_hands.pt +3 -0
- models/long_hair.pt +3 -0
- models/muscular.pt +3 -0
- models/pixar_style.pt +3 -0
- models/professional.pt +3 -0
- models/repair_slider.pt +3 -0
- models/sculpture_style.pt +3 -0
- models/smiling.pt +3 -0
- models/stylegan_latent1.pt +3 -0
- models/stylegan_latent2.pt +3 -0
- models/suprised_look.pt +3 -0
- models/tropical_weather.pt +3 -0
- models/winter_weather.pt +3 -0
- requirements.txt โ reqs.txt +0 -0
- trainscripts/__init__.py +1 -0
- trainscripts/imagesliders/config_util.py +104 -0
- trainscripts/imagesliders/data/config-xl.yaml +28 -0
- trainscripts/imagesliders/data/config.yaml +28 -0
- trainscripts/imagesliders/data/prompts-xl.yaml +275 -0
- trainscripts/imagesliders/data/prompts.yaml +174 -0
- trainscripts/imagesliders/debug_util.py +16 -0
- trainscripts/imagesliders/lora.py +256 -0
- trainscripts/imagesliders/model_util.py +283 -0
- trainscripts/imagesliders/prompt_util.py +174 -0
- trainscripts/imagesliders/train_lora-scale-xl.py +548 -0
- trainscripts/imagesliders/train_lora-scale.py +501 -0
- trainscripts/imagesliders/train_util.py +458 -0
- trainscripts/textsliders/__init__.py +0 -0
- trainscripts/textsliders/config_util.py +104 -0
- trainscripts/textsliders/data/config-xl.yaml +28 -0
- trainscripts/textsliders/data/config.yaml +28 -0
- trainscripts/textsliders/data/prompts-xl.yaml +477 -0
- trainscripts/textsliders/data/prompts.yaml +193 -0
- trainscripts/textsliders/debug_util.py +16 -0
- trainscripts/textsliders/flush.py +5 -0
- trainscripts/textsliders/generate_images_xl.py +513 -0
- trainscripts/textsliders/lora.py +258 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from trainscripts.textsliders import lora
|
app.py
CHANGED
@@ -1,7 +1,264 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from utils import call
|
5 |
+
from diffusers.pipelines import StableDiffusionXLPipeline
|
6 |
+
StableDiffusionXLPipeline.__call__ = call
|
7 |
|
8 |
+
model_map = {'Age' : 'models/age.pt',
|
9 |
+
'Chubby': 'models/chubby.pt',
|
10 |
+
'Muscular': 'models/muscular.pt',
|
11 |
+
'Wavy Eyebrows': 'models/eyebrows.pt',
|
12 |
+
'Small Eyes': 'models/eyesize.pt',
|
13 |
+
'Long Hair' : 'models/longhair.pt',
|
14 |
+
'Curly Hair' : 'models/curlyhair.pt',
|
15 |
+
'Smiling' : 'models/smiling.pt',
|
16 |
+
'Pixar Style' : 'models/pixar_style.pt',
|
17 |
+
'Sculpture Style': 'models/sculpture_style.pt',
|
18 |
+
'Repair Images': 'models/repair_slider.pt',
|
19 |
+
'Fix Hands': 'models/fix_hands.pt',
|
20 |
+
}
|
21 |
|
22 |
+
ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
|
23 |
+
SPACE_ID = os.getenv('SPACE_ID')
|
24 |
+
|
25 |
+
SHARED_UI_WARNING = f'''## Attention - Training does not work in this shared UI. You can either duplicate and use it with a gpu with at least 40GB, or clone this repository to run on your own machine.
|
26 |
+
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
|
27 |
+
'''
|
28 |
+
|
29 |
+
|
30 |
+
class Demo:
|
31 |
+
|
32 |
+
def __init__(self) -> None:
|
33 |
+
|
34 |
+
self.training = False
|
35 |
+
self.generating = False
|
36 |
+
self.device = 'cuda'
|
37 |
+
self.weight_dtype = torch.float16
|
38 |
+
self.pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=weight_dtype)
|
39 |
+
|
40 |
+
with gr.Blocks() as demo:
|
41 |
+
self.layout()
|
42 |
+
demo.queue(concurrency_count=5).launch()
|
43 |
+
|
44 |
+
|
45 |
+
def layout(self):
|
46 |
+
|
47 |
+
with gr.Row():
|
48 |
+
|
49 |
+
if SPACE_ID == ORIGINAL_SPACE_ID:
|
50 |
+
|
51 |
+
self.warning = gr.Markdown(SHARED_UI_WARNING)
|
52 |
+
|
53 |
+
with gr.Row():
|
54 |
+
|
55 |
+
with gr.Tab("Test") as inference_column:
|
56 |
+
|
57 |
+
with gr.Row():
|
58 |
+
|
59 |
+
self.explain_infr = gr.Markdown(interactive=False,
|
60 |
+
value='This is a demo of [Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models](https://sliders.baulab.info/). To try out a model that can control a particular concept, select a model and enter any prompt. For example, if you select the model "Surprised Look" you can generate images for the prompt "A picture of a person, realistic, 8k" and compare the slider effect to the image generated by original model. We have also provided several other pre-fine-tuned models like "repair" sliders to repair flaws in SDXL generated images (Check out the "Pretrained Sliders" drop-down). You can also train and run your own custom sliders. Check out the "train" section for custom concept slider training.')
|
61 |
+
|
62 |
+
with gr.Row():
|
63 |
+
|
64 |
+
with gr.Column(scale=1):
|
65 |
+
|
66 |
+
self.prompt_input_infr = gr.Text(
|
67 |
+
placeholder="Enter prompt...",
|
68 |
+
label="Prompt",
|
69 |
+
info="Prompt to generate"
|
70 |
+
)
|
71 |
+
|
72 |
+
with gr.Row():
|
73 |
+
|
74 |
+
self.model_dropdown = gr.Dropdown(
|
75 |
+
label="Pretrained Sliders",
|
76 |
+
choices= list(model_map.keys()),
|
77 |
+
value='Age',
|
78 |
+
interactive=True
|
79 |
+
)
|
80 |
+
|
81 |
+
self.seed_infr = gr.Number(
|
82 |
+
label="Seed",
|
83 |
+
value=12345
|
84 |
+
)
|
85 |
+
|
86 |
+
with gr.Column(scale=2):
|
87 |
+
|
88 |
+
self.infr_button = gr.Button(
|
89 |
+
value="Generate",
|
90 |
+
interactive=True
|
91 |
+
)
|
92 |
+
|
93 |
+
with gr.Row():
|
94 |
+
|
95 |
+
self.image_new = gr.Image(
|
96 |
+
label="Slider",
|
97 |
+
interactive=False
|
98 |
+
)
|
99 |
+
self.image_orig = gr.Image(
|
100 |
+
label="Original SD",
|
101 |
+
interactive=False
|
102 |
+
)
|
103 |
+
|
104 |
+
with gr.Tab("Train") as training_column:
|
105 |
+
|
106 |
+
with gr.Row():
|
107 |
+
|
108 |
+
self.explain_train= gr.Markdown(interactive=False,
|
109 |
+
value='In this part you can train a concept slider for Stable Diffusion XL. Enter a target concept you wish to make an edit on. Next, enter a enhance prompt of the attribute you wish to edit (for controlling age of a person, enter "person, old"). Then, type the supress prompt of the attribute (for our example, enter "person, young"). Then press "train" button. With default settings, it takes about 15 minutes to train a slider; then you can try inference above or download the weights. Code and details are at [github link](https://github.com/rohitgandikota/sliders).')
|
110 |
+
|
111 |
+
with gr.Row():
|
112 |
+
|
113 |
+
with gr.Column(scale=3):
|
114 |
+
|
115 |
+
self.target_concept = gr.Text(
|
116 |
+
placeholder="Enter target concept to make edit on ...",
|
117 |
+
label="Prompt of concept on which edit is made",
|
118 |
+
info="Prompt corresponding to concept to edit"
|
119 |
+
)
|
120 |
+
|
121 |
+
self.positive_prompt = gr.Text(
|
122 |
+
placeholder="Enter the enhance prompt for the edit...",
|
123 |
+
label="Prompt to enhance",
|
124 |
+
info="Prompt corresponding to concept to enhance"
|
125 |
+
)
|
126 |
+
|
127 |
+
self.negative_prompt = gr.Text(
|
128 |
+
placeholder="Enter the suppress prompt for the edit...",
|
129 |
+
label="Prompt to suppress",
|
130 |
+
info="Prompt corresponding to concept to supress"
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
self.rank = gr.Number(
|
135 |
+
value=4,
|
136 |
+
label="Rank of the Slider",
|
137 |
+
info='Slider Rank to train'
|
138 |
+
)
|
139 |
+
|
140 |
+
self.iterations_input = gr.Number(
|
141 |
+
value=1000,
|
142 |
+
precision=0,
|
143 |
+
label="Iterations",
|
144 |
+
info='iterations used to train'
|
145 |
+
)
|
146 |
+
|
147 |
+
self.lr_input = gr.Number(
|
148 |
+
value=2e-4,
|
149 |
+
label="Learning Rate",
|
150 |
+
info='Learning rate used to train'
|
151 |
+
)
|
152 |
+
|
153 |
+
with gr.Column(scale=1):
|
154 |
+
|
155 |
+
self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
|
156 |
+
|
157 |
+
self.train_button = gr.Button(
|
158 |
+
value="Train",
|
159 |
+
)
|
160 |
+
|
161 |
+
self.download = gr.Files()
|
162 |
+
|
163 |
+
self.infr_button.click(self.inference, inputs = [
|
164 |
+
self.prompt_input_infr,
|
165 |
+
self.seed_infr,
|
166 |
+
self.model_dropdown
|
167 |
+
],
|
168 |
+
outputs=[
|
169 |
+
self.image_new,
|
170 |
+
self.image_orig
|
171 |
+
]
|
172 |
+
)
|
173 |
+
self.train_button.click(self.train, inputs = [
|
174 |
+
self.target_concept,
|
175 |
+
self.positive_prompt,
|
176 |
+
slef.negative_prompt,
|
177 |
+
self.rank,
|
178 |
+
self.iterations_input,
|
179 |
+
self.lr_input
|
180 |
+
],
|
181 |
+
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
182 |
+
)
|
183 |
+
|
184 |
+
def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
185 |
+
|
186 |
+
if self.training:
|
187 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
|
188 |
+
|
189 |
+
if train_method == 'ESD-x':
|
190 |
+
|
191 |
+
modules = ".*attn2$"
|
192 |
+
frozen = []
|
193 |
+
|
194 |
+
elif train_method == 'ESD-u':
|
195 |
+
|
196 |
+
modules = "unet$"
|
197 |
+
frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"]
|
198 |
+
|
199 |
+
elif train_method == 'ESD-self':
|
200 |
+
|
201 |
+
modules = ".*attn1$"
|
202 |
+
frozen = []
|
203 |
+
|
204 |
+
randn = torch.randint(1, 10000000, (1,)).item()
|
205 |
+
|
206 |
+
save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
|
207 |
+
|
208 |
+
self.training = True
|
209 |
+
|
210 |
+
train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
|
211 |
+
|
212 |
+
self.training = False
|
213 |
+
|
214 |
+
torch.cuda.empty_cache()
|
215 |
+
|
216 |
+
model_map['Custom'] = save_path
|
217 |
+
|
218 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
219 |
+
|
220 |
+
|
221 |
+
def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
|
222 |
+
|
223 |
+
seed = seed or 12345
|
224 |
+
|
225 |
+
generator = torch.manual_seed(seed)
|
226 |
+
|
227 |
+
model_path = model_map[model_name]
|
228 |
+
|
229 |
+
checkpoint = torch.load(model_path)
|
230 |
+
|
231 |
+
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
|
232 |
+
|
233 |
+
torch.cuda.empty_cache()
|
234 |
+
|
235 |
+
images = self.diffuser(
|
236 |
+
prompt,
|
237 |
+
n_steps=50,
|
238 |
+
generator=generator
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
orig_image = images[0][0]
|
243 |
+
|
244 |
+
torch.cuda.empty_cache()
|
245 |
+
|
246 |
+
generator = torch.manual_seed(seed)
|
247 |
+
|
248 |
+
with finetuner:
|
249 |
+
|
250 |
+
images = self.diffuser(
|
251 |
+
prompt,
|
252 |
+
n_steps=50,
|
253 |
+
generator=generator
|
254 |
+
)
|
255 |
+
|
256 |
+
edited_image = images[0][0]
|
257 |
+
|
258 |
+
del finetuner
|
259 |
+
torch.cuda.empty_cache()
|
260 |
+
|
261 |
+
return edited_image, orig_image
|
262 |
+
|
263 |
+
|
264 |
+
demo = Demo()
|
models/age.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c1c096f7cc1109b4072cbc604c811a5f0ff034fc0f6dc7cf66a558550aa4890
|
3 |
+
size 9142347
|
models/cartoon_style.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e07c30e4f82f709a474ae11dc5108ac48f81b6996b937757c8dd198920ea9b4d
|
3 |
+
size 9146507
|
models/chubby.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a70fb34187821a06a39bf36baa400090a32758d56771c3f54fcc4d9089f0d88
|
3 |
+
size 9144427
|
models/clay_style.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b0deeb787248811fb8e54498768e303cffaeb3125d00c5fd303294af59a9380
|
3 |
+
size 9143387
|
models/cluttered_room.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee409a45bfaa7ca01fbffe63ec185c0f5ccf0e7b0fa67070a9e0b41886b7ea66
|
3 |
+
size 9140267
|
models/curlyhair.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d9b8d7d44da256291e3710f74954d352160ade5cbe291bce16c8f4951db31e7b
|
3 |
+
size 9136043
|
models/dark_weather.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eecd2ae8b35022cbfb9c32637d9fa8c3c0ca3aa5ea189369c027f938064ada3c
|
3 |
+
size 9135003
|
models/eyebrow.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:442770d2c30de92e30a1c2fcf9aab6b6cf5a3786eff84d513b7455345c79b57d
|
3 |
+
size 9135003
|
models/eyesize.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8fdffa3e7788f4bd6be9a2fe3b91957b4f35999fc9fa19eabfb49f92fbf6650b
|
3 |
+
size 9139227
|
models/festive.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:70d6c5d5be5f001510988852c2d233a916d23766675d9a000c6f785cd7e9127c
|
3 |
+
size 9133963
|
models/fix_hands.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d98c4828468c8d5831c439f49914672710f63219a561b191670fa54d542fa57b
|
3 |
+
size 9131883
|
models/long_hair.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e93dba27fa012bba0ea468eb2f9877ec0934424a9474e30ac9e94ea0517822ca
|
3 |
+
size 9147547
|
models/muscular.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3b46b8eeac992f2d0e76ce887ea45ec1ce70bfbae053876de26d1f33f986eb37
|
3 |
+
size 9135003
|
models/pixar_style.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e07c30e4f82f709a474ae11dc5108ac48f81b6996b937757c8dd198920ea9b4d
|
3 |
+
size 9146507
|
models/professional.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2d4289f4c60dd008fe487369ddccf3492bd678cc1e6b30de2c17f9ce802b12ac
|
3 |
+
size 9151707
|
models/repair_slider.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6e589e7d3b2174bb1d5d861a7218c4c26a94425b6dcdce0085b57f87ab841c5
|
3 |
+
size 9133963
|
models/sculpture_style.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2779746c08062ccb128fdaa6cb66f061070ac8f19386701a99fb9291392d5985
|
3 |
+
size 9148587
|
models/smiling.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6430ab47393ba15222ea0988c3479f547c8b59f93a41024bcddd7121ef7147d1
|
3 |
+
size 9146507
|
models/stylegan_latent1.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dca6cda8028af4587968cfed07c3bc6a2e79e5f8d01dad9351877f9de28f232d
|
3 |
+
size 9142347
|
models/stylegan_latent2.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4bbe239c399a4fc7b73a034b643c406106cd4c8392ad806ee3fd8dd8c80ba5fc
|
3 |
+
size 9142347
|
models/suprised_look.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:36806271ca61dced2a506430c6c0b53ace09c68f65a90e09778c2bb5bcad31d4
|
3 |
+
size 9148587
|
models/tropical_weather.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:215e5445bbb7288ebea2e523181ca6db991417deca2736de29f0c3a76eb69ac0
|
3 |
+
size 9135003
|
models/winter_weather.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:38f0bc81bc3cdef0c1c47895df6c9f0a9b98507f48928ef971f341e02c76bb4c
|
3 |
+
size 9132923
|
requirements.txt โ reqs.txt
RENAMED
File without changes
|
trainscripts/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from textsliders import lora
|
trainscripts/imagesliders/config_util.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Optional
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from lora import TRAINING_METHODS
|
9 |
+
|
10 |
+
PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
|
11 |
+
NETWORK_TYPES = Literal["lierla", "c3lier"]
|
12 |
+
|
13 |
+
|
14 |
+
class PretrainedModelConfig(BaseModel):
|
15 |
+
name_or_path: str
|
16 |
+
v2: bool = False
|
17 |
+
v_pred: bool = False
|
18 |
+
|
19 |
+
clip_skip: Optional[int] = None
|
20 |
+
|
21 |
+
|
22 |
+
class NetworkConfig(BaseModel):
|
23 |
+
type: NETWORK_TYPES = "lierla"
|
24 |
+
rank: int = 4
|
25 |
+
alpha: float = 1.0
|
26 |
+
|
27 |
+
training_method: TRAINING_METHODS = "full"
|
28 |
+
|
29 |
+
|
30 |
+
class TrainConfig(BaseModel):
|
31 |
+
precision: PRECISION_TYPES = "bfloat16"
|
32 |
+
noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
|
33 |
+
|
34 |
+
iterations: int = 500
|
35 |
+
lr: float = 1e-4
|
36 |
+
optimizer: str = "adamw"
|
37 |
+
optimizer_args: str = ""
|
38 |
+
lr_scheduler: str = "constant"
|
39 |
+
|
40 |
+
max_denoising_steps: int = 50
|
41 |
+
|
42 |
+
|
43 |
+
class SaveConfig(BaseModel):
|
44 |
+
name: str = "untitled"
|
45 |
+
path: str = "./output"
|
46 |
+
per_steps: int = 200
|
47 |
+
precision: PRECISION_TYPES = "float32"
|
48 |
+
|
49 |
+
|
50 |
+
class LoggingConfig(BaseModel):
|
51 |
+
use_wandb: bool = False
|
52 |
+
|
53 |
+
verbose: bool = False
|
54 |
+
|
55 |
+
|
56 |
+
class OtherConfig(BaseModel):
|
57 |
+
use_xformers: bool = False
|
58 |
+
|
59 |
+
|
60 |
+
class RootConfig(BaseModel):
|
61 |
+
prompts_file: str
|
62 |
+
pretrained_model: PretrainedModelConfig
|
63 |
+
|
64 |
+
network: NetworkConfig
|
65 |
+
|
66 |
+
train: Optional[TrainConfig]
|
67 |
+
|
68 |
+
save: Optional[SaveConfig]
|
69 |
+
|
70 |
+
logging: Optional[LoggingConfig]
|
71 |
+
|
72 |
+
other: Optional[OtherConfig]
|
73 |
+
|
74 |
+
|
75 |
+
def parse_precision(precision: str) -> torch.dtype:
|
76 |
+
if precision == "fp32" or precision == "float32":
|
77 |
+
return torch.float32
|
78 |
+
elif precision == "fp16" or precision == "float16":
|
79 |
+
return torch.float16
|
80 |
+
elif precision == "bf16" or precision == "bfloat16":
|
81 |
+
return torch.bfloat16
|
82 |
+
|
83 |
+
raise ValueError(f"Invalid precision type: {precision}")
|
84 |
+
|
85 |
+
|
86 |
+
def load_config_from_yaml(config_path: str) -> RootConfig:
|
87 |
+
with open(config_path, "r") as f:
|
88 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
89 |
+
|
90 |
+
root = RootConfig(**config)
|
91 |
+
|
92 |
+
if root.train is None:
|
93 |
+
root.train = TrainConfig()
|
94 |
+
|
95 |
+
if root.save is None:
|
96 |
+
root.save = SaveConfig()
|
97 |
+
|
98 |
+
if root.logging is None:
|
99 |
+
root.logging = LoggingConfig()
|
100 |
+
|
101 |
+
if root.other is None:
|
102 |
+
root.other = OtherConfig()
|
103 |
+
|
104 |
+
return root
|
trainscripts/imagesliders/data/config-xl.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompts_file: "trainscripts/imagesliders/data/prompts-xl.yaml"
|
2 |
+
pretrained_model:
|
3 |
+
name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
|
4 |
+
v2: false # true if model is v2.x
|
5 |
+
v_pred: false # true if model uses v-prediction
|
6 |
+
network:
|
7 |
+
type: "c3lier" # or "c3lier" or "lierla"
|
8 |
+
rank: 4
|
9 |
+
alpha: 1.0
|
10 |
+
training_method: "noxattn"
|
11 |
+
train:
|
12 |
+
precision: "bfloat16"
|
13 |
+
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
14 |
+
iterations: 1000
|
15 |
+
lr: 0.0002
|
16 |
+
optimizer: "AdamW"
|
17 |
+
lr_scheduler: "constant"
|
18 |
+
max_denoising_steps: 50
|
19 |
+
save:
|
20 |
+
name: "temp"
|
21 |
+
path: "./models"
|
22 |
+
per_steps: 500
|
23 |
+
precision: "bfloat16"
|
24 |
+
logging:
|
25 |
+
use_wandb: false
|
26 |
+
verbose: false
|
27 |
+
other:
|
28 |
+
use_xformers: true
|
trainscripts/imagesliders/data/config.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompts_file: "trainscripts/imagesliders/data/prompts.yaml"
|
2 |
+
pretrained_model:
|
3 |
+
name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
|
4 |
+
v2: false # true if model is v2.x
|
5 |
+
v_pred: false # true if model uses v-prediction
|
6 |
+
network:
|
7 |
+
type: "c3lier" # or "c3lier" or "lierla"
|
8 |
+
rank: 4
|
9 |
+
alpha: 1.0
|
10 |
+
training_method: "noxattn"
|
11 |
+
train:
|
12 |
+
precision: "bfloat16"
|
13 |
+
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
14 |
+
iterations: 1000
|
15 |
+
lr: 0.0002
|
16 |
+
optimizer: "AdamW"
|
17 |
+
lr_scheduler: "constant"
|
18 |
+
max_denoising_steps: 50
|
19 |
+
save:
|
20 |
+
name: "temp"
|
21 |
+
path: "./models"
|
22 |
+
per_steps: 500
|
23 |
+
precision: "bfloat16"
|
24 |
+
logging:
|
25 |
+
use_wandb: false
|
26 |
+
verbose: false
|
27 |
+
other:
|
28 |
+
use_xformers: true
|
trainscripts/imagesliders/data/prompts-xl.yaml
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
####################################################################################################### AGE SLIDER
|
2 |
+
# - target: "male person" # what word for erasing the positive concept from
|
3 |
+
# positive: "male person, very old" # concept to erase
|
4 |
+
# unconditional: "male person, very young" # word to take the difference from the positive concept
|
5 |
+
# neutral: "male person" # starting point for conditioning the target
|
6 |
+
# action: "enhance" # erase or enhance
|
7 |
+
# guidance_scale: 4
|
8 |
+
# resolution: 512
|
9 |
+
# dynamic_resolution: false
|
10 |
+
# batch_size: 1
|
11 |
+
# - target: "female person" # what word for erasing the positive concept from
|
12 |
+
# positive: "female person, very old" # concept to erase
|
13 |
+
# unconditional: "female person, very young" # word to take the difference from the positive concept
|
14 |
+
# neutral: "female person" # starting point for conditioning the target
|
15 |
+
# action: "enhance" # erase or enhance
|
16 |
+
# guidance_scale: 4
|
17 |
+
# resolution: 512
|
18 |
+
# dynamic_resolution: false
|
19 |
+
# batch_size: 1
|
20 |
+
####################################################################################################### GLASSES SLIDER
|
21 |
+
# - target: "male person" # what word for erasing the positive concept from
|
22 |
+
# positive: "male person, wearing glasses" # concept to erase
|
23 |
+
# unconditional: "male person" # word to take the difference from the positive concept
|
24 |
+
# neutral: "male person" # starting point for conditioning the target
|
25 |
+
# action: "enhance" # erase or enhance
|
26 |
+
# guidance_scale: 4
|
27 |
+
# resolution: 512
|
28 |
+
# dynamic_resolution: false
|
29 |
+
# batch_size: 1
|
30 |
+
# - target: "female person" # what word for erasing the positive concept from
|
31 |
+
# positive: "female person, wearing glasses" # concept to erase
|
32 |
+
# unconditional: "female person" # word to take the difference from the positive concept
|
33 |
+
# neutral: "female person" # starting point for conditioning the target
|
34 |
+
# action: "enhance" # erase or enhance
|
35 |
+
# guidance_scale: 4
|
36 |
+
# resolution: 512
|
37 |
+
# dynamic_resolution: false
|
38 |
+
# batch_size: 1
|
39 |
+
####################################################################################################### ASTRONAUGHT SLIDER
|
40 |
+
# - target: "astronaught" # what word for erasing the positive concept from
|
41 |
+
# positive: "astronaught, with orange colored spacesuit" # concept to erase
|
42 |
+
# unconditional: "astronaught" # word to take the difference from the positive concept
|
43 |
+
# neutral: "astronaught" # starting point for conditioning the target
|
44 |
+
# action: "enhance" # erase or enhance
|
45 |
+
# guidance_scale: 4
|
46 |
+
# resolution: 512
|
47 |
+
# dynamic_resolution: false
|
48 |
+
# batch_size: 1
|
49 |
+
####################################################################################################### SMILING SLIDER
|
50 |
+
# - target: "male person" # what word for erasing the positive concept from
|
51 |
+
# positive: "male person, smiling" # concept to erase
|
52 |
+
# unconditional: "male person, frowning" # word to take the difference from the positive concept
|
53 |
+
# neutral: "male person" # starting point for conditioning the target
|
54 |
+
# action: "enhance" # erase or enhance
|
55 |
+
# guidance_scale: 4
|
56 |
+
# resolution: 512
|
57 |
+
# dynamic_resolution: false
|
58 |
+
# batch_size: 1
|
59 |
+
# - target: "female person" # what word for erasing the positive concept from
|
60 |
+
# positive: "female person, smiling" # concept to erase
|
61 |
+
# unconditional: "female person, frowning" # word to take the difference from the positive concept
|
62 |
+
# neutral: "female person" # starting point for conditioning the target
|
63 |
+
# action: "enhance" # erase or enhance
|
64 |
+
# guidance_scale: 4
|
65 |
+
# resolution: 512
|
66 |
+
# dynamic_resolution: false
|
67 |
+
# batch_size: 1
|
68 |
+
####################################################################################################### CAR COLOR SLIDER
|
69 |
+
# - target: "car" # what word for erasing the positive concept from
|
70 |
+
# positive: "car, white color" # concept to erase
|
71 |
+
# unconditional: "car, black color" # word to take the difference from the positive concept
|
72 |
+
# neutral: "car" # starting point for conditioning the target
|
73 |
+
# action: "enhance" # erase or enhance
|
74 |
+
# guidance_scale: 4
|
75 |
+
# resolution: 512
|
76 |
+
# dynamic_resolution: false
|
77 |
+
# batch_size: 1
|
78 |
+
####################################################################################################### DETAILS SLIDER
|
79 |
+
# - target: "" # what word for erasing the positive concept from
|
80 |
+
# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase
|
81 |
+
# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
|
82 |
+
# neutral: "" # starting point for conditioning the target
|
83 |
+
# action: "enhance" # erase or enhance
|
84 |
+
# guidance_scale: 4
|
85 |
+
# resolution: 512
|
86 |
+
# dynamic_resolution: false
|
87 |
+
# batch_size: 1
|
88 |
+
####################################################################################################### BOKEH SLIDER
|
89 |
+
# - target: "" # what word for erasing the positive concept from
|
90 |
+
# positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
|
91 |
+
# # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
|
92 |
+
# unconditional: ""
|
93 |
+
# neutral: "" # starting point for conditioning the target
|
94 |
+
# action: "enhance" # erase or enhance
|
95 |
+
# guidance_scale: 4
|
96 |
+
# resolution: 512
|
97 |
+
# dynamic_resolution: false
|
98 |
+
# batch_size: 1
|
99 |
+
####################################################################################################### LONG HAIR SLIDER
|
100 |
+
# - target: "male person" # what word for erasing the positive concept from
|
101 |
+
# positive: "male person, with long hair" # concept to erase
|
102 |
+
# unconditional: "male person, with short hair" # word to take the difference from the positive concept
|
103 |
+
# neutral: "male person" # starting point for conditioning the target
|
104 |
+
# action: "enhance" # erase or enhance
|
105 |
+
# guidance_scale: 4
|
106 |
+
# resolution: 512
|
107 |
+
# dynamic_resolution: false
|
108 |
+
# batch_size: 1
|
109 |
+
# - target: "female person" # what word for erasing the positive concept from
|
110 |
+
# positive: "female person, with long hair" # concept to erase
|
111 |
+
# unconditional: "female person, with short hair" # word to take the difference from the positive concept
|
112 |
+
# neutral: "female person" # starting point for conditioning the target
|
113 |
+
# action: "enhance" # erase or enhance
|
114 |
+
# guidance_scale: 4
|
115 |
+
# resolution: 512
|
116 |
+
# dynamic_resolution: false
|
117 |
+
# batch_size: 1
|
118 |
+
####################################################################################################### IMAGE SLIDER
|
119 |
+
- target: "" # what word for erasing the positive concept from
|
120 |
+
positive: "" # concept to erase
|
121 |
+
unconditional: "" # word to take the difference from the positive concept
|
122 |
+
neutral: "" # starting point for conditioning the target
|
123 |
+
action: "enhance" # erase or enhance
|
124 |
+
guidance_scale: 4
|
125 |
+
resolution: 512
|
126 |
+
dynamic_resolution: false
|
127 |
+
batch_size: 1
|
128 |
+
####################################################################################################### IMAGE SLIDER
|
129 |
+
# - target: "food" # what word for erasing the positive concept from
|
130 |
+
# positive: "food, expensive and fine dining" # concept to erase
|
131 |
+
# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
|
132 |
+
# neutral: "food" # starting point for conditioning the target
|
133 |
+
# action: "enhance" # erase or enhance
|
134 |
+
# guidance_scale: 4
|
135 |
+
# resolution: 512
|
136 |
+
# dynamic_resolution: false
|
137 |
+
# batch_size: 1
|
138 |
+
# - target: "room" # what word for erasing the positive concept from
|
139 |
+
# positive: "room, dirty disorganised and cluttered" # concept to erase
|
140 |
+
# unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
|
141 |
+
# neutral: "room" # starting point for conditioning the target
|
142 |
+
# action: "enhance" # erase or enhance
|
143 |
+
# guidance_scale: 4
|
144 |
+
# resolution: 512
|
145 |
+
# dynamic_resolution: false
|
146 |
+
# batch_size: 1
|
147 |
+
# - target: "male person" # what word for erasing the positive concept from
|
148 |
+
# positive: "male person, with a surprised look" # concept to erase
|
149 |
+
# unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
|
150 |
+
# neutral: "male person" # starting point for conditioning the target
|
151 |
+
# action: "enhance" # erase or enhance
|
152 |
+
# guidance_scale: 4
|
153 |
+
# resolution: 512
|
154 |
+
# dynamic_resolution: false
|
155 |
+
# batch_size: 1
|
156 |
+
# - target: "female person" # what word for erasing the positive concept from
|
157 |
+
# positive: "female person, with a surprised look" # concept to erase
|
158 |
+
# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
|
159 |
+
# neutral: "female person" # starting point for conditioning the target
|
160 |
+
# action: "enhance" # erase or enhance
|
161 |
+
# guidance_scale: 4
|
162 |
+
# resolution: 512
|
163 |
+
# dynamic_resolution: false
|
164 |
+
# batch_size: 1
|
165 |
+
# - target: "sky" # what word for erasing the positive concept from
|
166 |
+
# positive: "peaceful sky" # concept to erase
|
167 |
+
# unconditional: "sky" # word to take the difference from the positive concept
|
168 |
+
# neutral: "sky" # starting point for conditioning the target
|
169 |
+
# action: "enhance" # erase or enhance
|
170 |
+
# guidance_scale: 4
|
171 |
+
# resolution: 512
|
172 |
+
# dynamic_resolution: false
|
173 |
+
# batch_size: 1
|
174 |
+
# - target: "sky" # what word for erasing the positive concept from
|
175 |
+
# positive: "chaotic dark sky" # concept to erase
|
176 |
+
# unconditional: "sky" # word to take the difference from the positive concept
|
177 |
+
# neutral: "sky" # starting point for conditioning the target
|
178 |
+
# action: "erase" # erase or enhance
|
179 |
+
# guidance_scale: 4
|
180 |
+
# resolution: 512
|
181 |
+
# dynamic_resolution: false
|
182 |
+
# batch_size: 1
|
183 |
+
# - target: "person" # what word for erasing the positive concept from
|
184 |
+
# positive: "person, very young" # concept to erase
|
185 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
186 |
+
# neutral: "person" # starting point for conditioning the target
|
187 |
+
# action: "erase" # erase or enhance
|
188 |
+
# guidance_scale: 4
|
189 |
+
# resolution: 512
|
190 |
+
# dynamic_resolution: false
|
191 |
+
# batch_size: 1
|
192 |
+
# overweight
|
193 |
+
# - target: "art" # what word for erasing the positive concept from
|
194 |
+
# positive: "realistic art" # concept to erase
|
195 |
+
# unconditional: "art" # word to take the difference from the positive concept
|
196 |
+
# neutral: "art" # starting point for conditioning the target
|
197 |
+
# action: "enhance" # erase or enhance
|
198 |
+
# guidance_scale: 4
|
199 |
+
# resolution: 512
|
200 |
+
# dynamic_resolution: false
|
201 |
+
# batch_size: 1
|
202 |
+
# - target: "art" # what word for erasing the positive concept from
|
203 |
+
# positive: "abstract art" # concept to erase
|
204 |
+
# unconditional: "art" # word to take the difference from the positive concept
|
205 |
+
# neutral: "art" # starting point for conditioning the target
|
206 |
+
# action: "erase" # erase or enhance
|
207 |
+
# guidance_scale: 4
|
208 |
+
# resolution: 512
|
209 |
+
# dynamic_resolution: false
|
210 |
+
# batch_size: 1
|
211 |
+
# sky
|
212 |
+
# - target: "weather" # what word for erasing the positive concept from
|
213 |
+
# positive: "bright pleasant weather" # concept to erase
|
214 |
+
# unconditional: "weather" # word to take the difference from the positive concept
|
215 |
+
# neutral: "weather" # starting point for conditioning the target
|
216 |
+
# action: "enhance" # erase or enhance
|
217 |
+
# guidance_scale: 4
|
218 |
+
# resolution: 512
|
219 |
+
# dynamic_resolution: false
|
220 |
+
# batch_size: 1
|
221 |
+
# - target: "weather" # what word for erasing the positive concept from
|
222 |
+
# positive: "dark gloomy weather" # concept to erase
|
223 |
+
# unconditional: "weather" # word to take the difference from the positive concept
|
224 |
+
# neutral: "weather" # starting point for conditioning the target
|
225 |
+
# action: "erase" # erase or enhance
|
226 |
+
# guidance_scale: 4
|
227 |
+
# resolution: 512
|
228 |
+
# dynamic_resolution: false
|
229 |
+
# batch_size: 1
|
230 |
+
# hair
|
231 |
+
# - target: "person" # what word for erasing the positive concept from
|
232 |
+
# positive: "person with long hair" # concept to erase
|
233 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
234 |
+
# neutral: "person" # starting point for conditioning the target
|
235 |
+
# action: "enhance" # erase or enhance
|
236 |
+
# guidance_scale: 4
|
237 |
+
# resolution: 512
|
238 |
+
# dynamic_resolution: false
|
239 |
+
# batch_size: 1
|
240 |
+
# - target: "person" # what word for erasing the positive concept from
|
241 |
+
# positive: "person with short hair" # concept to erase
|
242 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
243 |
+
# neutral: "person" # starting point for conditioning the target
|
244 |
+
# action: "erase" # erase or enhance
|
245 |
+
# guidance_scale: 4
|
246 |
+
# resolution: 512
|
247 |
+
# dynamic_resolution: false
|
248 |
+
# batch_size: 1
|
249 |
+
# - target: "girl" # what word for erasing the positive concept from
|
250 |
+
# positive: "baby girl" # concept to erase
|
251 |
+
# unconditional: "girl" # word to take the difference from the positive concept
|
252 |
+
# neutral: "girl" # starting point for conditioning the target
|
253 |
+
# action: "enhance" # erase or enhance
|
254 |
+
# guidance_scale: -4
|
255 |
+
# resolution: 512
|
256 |
+
# dynamic_resolution: false
|
257 |
+
# batch_size: 1
|
258 |
+
# - target: "boy" # what word for erasing the positive concept from
|
259 |
+
# positive: "old man" # concept to erase
|
260 |
+
# unconditional: "boy" # word to take the difference from the positive concept
|
261 |
+
# neutral: "boy" # starting point for conditioning the target
|
262 |
+
# action: "enhance" # erase or enhance
|
263 |
+
# guidance_scale: 4
|
264 |
+
# resolution: 512
|
265 |
+
# dynamic_resolution: false
|
266 |
+
# batch_size: 1
|
267 |
+
# - target: "boy" # what word for erasing the positive concept from
|
268 |
+
# positive: "baby boy" # concept to erase
|
269 |
+
# unconditional: "boy" # word to take the difference from the positive concept
|
270 |
+
# neutral: "boy" # starting point for conditioning the target
|
271 |
+
# action: "enhance" # erase or enhance
|
272 |
+
# guidance_scale: -4
|
273 |
+
# resolution: 512
|
274 |
+
# dynamic_resolution: false
|
275 |
+
# batch_size: 1
|
trainscripts/imagesliders/data/prompts.yaml
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# - target: "person" # what word for erasing the positive concept from
|
2 |
+
# positive: "person, very old" # concept to erase
|
3 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
4 |
+
# neutral: "person" # starting point for conditioning the target
|
5 |
+
# action: "enhance" # erase or enhance
|
6 |
+
# guidance_scale: 4
|
7 |
+
# resolution: 512
|
8 |
+
# dynamic_resolution: false
|
9 |
+
# batch_size: 1
|
10 |
+
- target: "" # what word for erasing the positive concept from
|
11 |
+
positive: "" # concept to erase
|
12 |
+
unconditional: "" # word to take the difference from the positive concept
|
13 |
+
neutral: "" # starting point for conditioning the target
|
14 |
+
action: "enhance" # erase or enhance
|
15 |
+
guidance_scale: 1
|
16 |
+
resolution: 512
|
17 |
+
dynamic_resolution: false
|
18 |
+
batch_size: 1
|
19 |
+
# - target: "" # what word for erasing the positive concept from
|
20 |
+
# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
|
21 |
+
# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
|
22 |
+
# neutral: "" # starting point for conditioning the target
|
23 |
+
# action: "enhance" # erase or enhance
|
24 |
+
# guidance_scale: 4
|
25 |
+
# resolution: 512
|
26 |
+
# dynamic_resolution: false
|
27 |
+
# batch_size: 1
|
28 |
+
# - target: "food" # what word for erasing the positive concept from
|
29 |
+
# positive: "food, expensive and fine dining" # concept to erase
|
30 |
+
# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
|
31 |
+
# neutral: "food" # starting point for conditioning the target
|
32 |
+
# action: "enhance" # erase or enhance
|
33 |
+
# guidance_scale: 4
|
34 |
+
# resolution: 512
|
35 |
+
# dynamic_resolution: false
|
36 |
+
# batch_size: 1
|
37 |
+
# - target: "room" # what word for erasing the positive concept from
|
38 |
+
# positive: "room, dirty disorganised and cluttered" # concept to erase
|
39 |
+
# unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
|
40 |
+
# neutral: "room" # starting point for conditioning the target
|
41 |
+
# action: "enhance" # erase or enhance
|
42 |
+
# guidance_scale: 4
|
43 |
+
# resolution: 512
|
44 |
+
# dynamic_resolution: false
|
45 |
+
# batch_size: 1
|
46 |
+
# - target: "male person" # what word for erasing the positive concept from
|
47 |
+
# positive: "male person, with a surprised look" # concept to erase
|
48 |
+
# unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
|
49 |
+
# neutral: "male person" # starting point for conditioning the target
|
50 |
+
# action: "enhance" # erase or enhance
|
51 |
+
# guidance_scale: 4
|
52 |
+
# resolution: 512
|
53 |
+
# dynamic_resolution: false
|
54 |
+
# batch_size: 1
|
55 |
+
# - target: "female person" # what word for erasing the positive concept from
|
56 |
+
# positive: "female person, with a surprised look" # concept to erase
|
57 |
+
# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
|
58 |
+
# neutral: "female person" # starting point for conditioning the target
|
59 |
+
# action: "enhance" # erase or enhance
|
60 |
+
# guidance_scale: 4
|
61 |
+
# resolution: 512
|
62 |
+
# dynamic_resolution: false
|
63 |
+
# batch_size: 1
|
64 |
+
# - target: "sky" # what word for erasing the positive concept from
|
65 |
+
# positive: "peaceful sky" # concept to erase
|
66 |
+
# unconditional: "sky" # word to take the difference from the positive concept
|
67 |
+
# neutral: "sky" # starting point for conditioning the target
|
68 |
+
# action: "enhance" # erase or enhance
|
69 |
+
# guidance_scale: 4
|
70 |
+
# resolution: 512
|
71 |
+
# dynamic_resolution: false
|
72 |
+
# batch_size: 1
|
73 |
+
# - target: "sky" # what word for erasing the positive concept from
|
74 |
+
# positive: "chaotic dark sky" # concept to erase
|
75 |
+
# unconditional: "sky" # word to take the difference from the positive concept
|
76 |
+
# neutral: "sky" # starting point for conditioning the target
|
77 |
+
# action: "erase" # erase or enhance
|
78 |
+
# guidance_scale: 4
|
79 |
+
# resolution: 512
|
80 |
+
# dynamic_resolution: false
|
81 |
+
# batch_size: 1
|
82 |
+
# - target: "person" # what word for erasing the positive concept from
|
83 |
+
# positive: "person, very young" # concept to erase
|
84 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
85 |
+
# neutral: "person" # starting point for conditioning the target
|
86 |
+
# action: "erase" # erase or enhance
|
87 |
+
# guidance_scale: 4
|
88 |
+
# resolution: 512
|
89 |
+
# dynamic_resolution: false
|
90 |
+
# batch_size: 1
|
91 |
+
# overweight
|
92 |
+
# - target: "art" # what word for erasing the positive concept from
|
93 |
+
# positive: "realistic art" # concept to erase
|
94 |
+
# unconditional: "art" # word to take the difference from the positive concept
|
95 |
+
# neutral: "art" # starting point for conditioning the target
|
96 |
+
# action: "enhance" # erase or enhance
|
97 |
+
# guidance_scale: 4
|
98 |
+
# resolution: 512
|
99 |
+
# dynamic_resolution: false
|
100 |
+
# batch_size: 1
|
101 |
+
# - target: "art" # what word for erasing the positive concept from
|
102 |
+
# positive: "abstract art" # concept to erase
|
103 |
+
# unconditional: "art" # word to take the difference from the positive concept
|
104 |
+
# neutral: "art" # starting point for conditioning the target
|
105 |
+
# action: "erase" # erase or enhance
|
106 |
+
# guidance_scale: 4
|
107 |
+
# resolution: 512
|
108 |
+
# dynamic_resolution: false
|
109 |
+
# batch_size: 1
|
110 |
+
# sky
|
111 |
+
# - target: "weather" # what word for erasing the positive concept from
|
112 |
+
# positive: "bright pleasant weather" # concept to erase
|
113 |
+
# unconditional: "weather" # word to take the difference from the positive concept
|
114 |
+
# neutral: "weather" # starting point for conditioning the target
|
115 |
+
# action: "enhance" # erase or enhance
|
116 |
+
# guidance_scale: 4
|
117 |
+
# resolution: 512
|
118 |
+
# dynamic_resolution: false
|
119 |
+
# batch_size: 1
|
120 |
+
# - target: "weather" # what word for erasing the positive concept from
|
121 |
+
# positive: "dark gloomy weather" # concept to erase
|
122 |
+
# unconditional: "weather" # word to take the difference from the positive concept
|
123 |
+
# neutral: "weather" # starting point for conditioning the target
|
124 |
+
# action: "erase" # erase or enhance
|
125 |
+
# guidance_scale: 4
|
126 |
+
# resolution: 512
|
127 |
+
# dynamic_resolution: false
|
128 |
+
# batch_size: 1
|
129 |
+
# hair
|
130 |
+
# - target: "person" # what word for erasing the positive concept from
|
131 |
+
# positive: "person with long hair" # concept to erase
|
132 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
133 |
+
# neutral: "person" # starting point for conditioning the target
|
134 |
+
# action: "enhance" # erase or enhance
|
135 |
+
# guidance_scale: 4
|
136 |
+
# resolution: 512
|
137 |
+
# dynamic_resolution: false
|
138 |
+
# batch_size: 1
|
139 |
+
# - target: "person" # what word for erasing the positive concept from
|
140 |
+
# positive: "person with short hair" # concept to erase
|
141 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
142 |
+
# neutral: "person" # starting point for conditioning the target
|
143 |
+
# action: "erase" # erase or enhance
|
144 |
+
# guidance_scale: 4
|
145 |
+
# resolution: 512
|
146 |
+
# dynamic_resolution: false
|
147 |
+
# batch_size: 1
|
148 |
+
# - target: "girl" # what word for erasing the positive concept from
|
149 |
+
# positive: "baby girl" # concept to erase
|
150 |
+
# unconditional: "girl" # word to take the difference from the positive concept
|
151 |
+
# neutral: "girl" # starting point for conditioning the target
|
152 |
+
# action: "enhance" # erase or enhance
|
153 |
+
# guidance_scale: -4
|
154 |
+
# resolution: 512
|
155 |
+
# dynamic_resolution: false
|
156 |
+
# batch_size: 1
|
157 |
+
# - target: "boy" # what word for erasing the positive concept from
|
158 |
+
# positive: "old man" # concept to erase
|
159 |
+
# unconditional: "boy" # word to take the difference from the positive concept
|
160 |
+
# neutral: "boy" # starting point for conditioning the target
|
161 |
+
# action: "enhance" # erase or enhance
|
162 |
+
# guidance_scale: 4
|
163 |
+
# resolution: 512
|
164 |
+
# dynamic_resolution: false
|
165 |
+
# batch_size: 1
|
166 |
+
# - target: "boy" # what word for erasing the positive concept from
|
167 |
+
# positive: "baby boy" # concept to erase
|
168 |
+
# unconditional: "boy" # word to take the difference from the positive concept
|
169 |
+
# neutral: "boy" # starting point for conditioning the target
|
170 |
+
# action: "enhance" # erase or enhance
|
171 |
+
# guidance_scale: -4
|
172 |
+
# resolution: 512
|
173 |
+
# dynamic_resolution: false
|
174 |
+
# batch_size: 1
|
trainscripts/imagesliders/debug_util.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ใใใใฐ็จ...
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def check_requires_grad(model: torch.nn.Module):
|
7 |
+
for name, module in list(model.named_modules())[:5]:
|
8 |
+
if len(list(module.parameters())) > 0:
|
9 |
+
print(f"Module: {name}")
|
10 |
+
for name, param in list(module.named_parameters())[:2]:
|
11 |
+
print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
|
12 |
+
|
13 |
+
|
14 |
+
def check_training_mode(model: torch.nn.Module):
|
15 |
+
for name, module in list(model.named_modules())[:5]:
|
16 |
+
print(f"Module: {name}, Training Mode: {module.training}")
|
trainscripts/imagesliders/lora.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ref:
|
2 |
+
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
3 |
+
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
4 |
+
|
5 |
+
import os
|
6 |
+
import math
|
7 |
+
from typing import Optional, List, Type, Set, Literal
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from diffusers import UNet2DConditionModel
|
12 |
+
from safetensors.torch import save_file
|
13 |
+
|
14 |
+
|
15 |
+
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
|
16 |
+
# "Transformer2DModel", # ใฉใใใใใฃใกใฎๆนใใใ๏ผ # attn1, 2
|
17 |
+
"Attention"
|
18 |
+
]
|
19 |
+
UNET_TARGET_REPLACE_MODULE_CONV = [
|
20 |
+
"ResnetBlock2D",
|
21 |
+
"Downsample2D",
|
22 |
+
"Upsample2D",
|
23 |
+
# "DownBlock2D",
|
24 |
+
# "UpBlock2D"
|
25 |
+
] # locon, 3clier
|
26 |
+
|
27 |
+
LORA_PREFIX_UNET = "lora_unet"
|
28 |
+
|
29 |
+
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
|
30 |
+
|
31 |
+
TRAINING_METHODS = Literal[
|
32 |
+
"noxattn", # train all layers except x-attns and time_embed layers
|
33 |
+
"innoxattn", # train all layers except self attention layers
|
34 |
+
"selfattn", # ESD-u, train only self attention layers
|
35 |
+
"xattn", # ESD-x, train only x attention layers
|
36 |
+
"full", # train all layers
|
37 |
+
"xattn-strict", # q and k values
|
38 |
+
"noxattn-hspace",
|
39 |
+
"noxattn-hspace-last",
|
40 |
+
# "xlayer",
|
41 |
+
# "outxattn",
|
42 |
+
# "outsattn",
|
43 |
+
# "inxattn",
|
44 |
+
# "inmidsattn",
|
45 |
+
# "selflayer",
|
46 |
+
]
|
47 |
+
|
48 |
+
|
49 |
+
class LoRAModule(nn.Module):
|
50 |
+
"""
|
51 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
lora_name,
|
57 |
+
org_module: nn.Module,
|
58 |
+
multiplier=1.0,
|
59 |
+
lora_dim=4,
|
60 |
+
alpha=1,
|
61 |
+
):
|
62 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
63 |
+
super().__init__()
|
64 |
+
self.lora_name = lora_name
|
65 |
+
self.lora_dim = lora_dim
|
66 |
+
|
67 |
+
if "Linear" in org_module.__class__.__name__:
|
68 |
+
in_dim = org_module.in_features
|
69 |
+
out_dim = org_module.out_features
|
70 |
+
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
71 |
+
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
72 |
+
|
73 |
+
elif "Conv" in org_module.__class__.__name__: # ไธๅฟ
|
74 |
+
in_dim = org_module.in_channels
|
75 |
+
out_dim = org_module.out_channels
|
76 |
+
|
77 |
+
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
|
78 |
+
if self.lora_dim != lora_dim:
|
79 |
+
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
80 |
+
|
81 |
+
kernel_size = org_module.kernel_size
|
82 |
+
stride = org_module.stride
|
83 |
+
padding = org_module.padding
|
84 |
+
self.lora_down = nn.Conv2d(
|
85 |
+
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
|
86 |
+
)
|
87 |
+
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
88 |
+
|
89 |
+
if type(alpha) == torch.Tensor:
|
90 |
+
alpha = alpha.detach().numpy()
|
91 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
92 |
+
self.scale = alpha / self.lora_dim
|
93 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # ๅฎๆฐใจใใฆๆฑใใ
|
94 |
+
|
95 |
+
# same as microsoft's
|
96 |
+
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
97 |
+
nn.init.zeros_(self.lora_up.weight)
|
98 |
+
|
99 |
+
self.multiplier = multiplier
|
100 |
+
self.org_module = org_module # remove in applying
|
101 |
+
|
102 |
+
def apply_to(self):
|
103 |
+
self.org_forward = self.org_module.forward
|
104 |
+
self.org_module.forward = self.forward
|
105 |
+
del self.org_module
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return (
|
109 |
+
self.org_forward(x)
|
110 |
+
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
class LoRANetwork(nn.Module):
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
unet: UNet2DConditionModel,
|
118 |
+
rank: int = 4,
|
119 |
+
multiplier: float = 1.0,
|
120 |
+
alpha: float = 1.0,
|
121 |
+
train_method: TRAINING_METHODS = "full",
|
122 |
+
) -> None:
|
123 |
+
super().__init__()
|
124 |
+
self.lora_scale = 1
|
125 |
+
self.multiplier = multiplier
|
126 |
+
self.lora_dim = rank
|
127 |
+
self.alpha = alpha
|
128 |
+
|
129 |
+
# LoRAใฎใฟ
|
130 |
+
self.module = LoRAModule
|
131 |
+
|
132 |
+
# unetใฎloraใไฝใ
|
133 |
+
self.unet_loras = self.create_modules(
|
134 |
+
LORA_PREFIX_UNET,
|
135 |
+
unet,
|
136 |
+
DEFAULT_TARGET_REPLACE,
|
137 |
+
self.lora_dim,
|
138 |
+
self.multiplier,
|
139 |
+
train_method=train_method,
|
140 |
+
)
|
141 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
142 |
+
|
143 |
+
# assertion ๅๅใฎ่ขซใใใชใใ็ขบ่ชใใฆใใใใใ
|
144 |
+
lora_names = set()
|
145 |
+
for lora in self.unet_loras:
|
146 |
+
assert (
|
147 |
+
lora.lora_name not in lora_names
|
148 |
+
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
|
149 |
+
lora_names.add(lora.lora_name)
|
150 |
+
|
151 |
+
# ้ฉ็จใใ
|
152 |
+
for lora in self.unet_loras:
|
153 |
+
lora.apply_to()
|
154 |
+
self.add_module(
|
155 |
+
lora.lora_name,
|
156 |
+
lora,
|
157 |
+
)
|
158 |
+
|
159 |
+
del unet
|
160 |
+
|
161 |
+
torch.cuda.empty_cache()
|
162 |
+
|
163 |
+
def create_modules(
|
164 |
+
self,
|
165 |
+
prefix: str,
|
166 |
+
root_module: nn.Module,
|
167 |
+
target_replace_modules: List[str],
|
168 |
+
rank: int,
|
169 |
+
multiplier: float,
|
170 |
+
train_method: TRAINING_METHODS,
|
171 |
+
) -> list:
|
172 |
+
loras = []
|
173 |
+
names = []
|
174 |
+
for name, module in root_module.named_modules():
|
175 |
+
if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention ใจ Time Embed ไปฅๅคๅญฆ็ฟ
|
176 |
+
if "attn2" in name or "time_embed" in name:
|
177 |
+
continue
|
178 |
+
elif train_method == "innoxattn": # Cross Attention ไปฅๅคๅญฆ็ฟ
|
179 |
+
if "attn2" in name:
|
180 |
+
continue
|
181 |
+
elif train_method == "selfattn": # Self Attention ใฎใฟๅญฆ็ฟ
|
182 |
+
if "attn1" not in name:
|
183 |
+
continue
|
184 |
+
elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention ใฎใฟๅญฆ็ฟ
|
185 |
+
if "attn2" not in name:
|
186 |
+
continue
|
187 |
+
elif train_method == "full": # ๅ
จ้จๅญฆ็ฟ
|
188 |
+
pass
|
189 |
+
else:
|
190 |
+
raise NotImplementedError(
|
191 |
+
f"train_method: {train_method} is not implemented."
|
192 |
+
)
|
193 |
+
if module.__class__.__name__ in target_replace_modules:
|
194 |
+
for child_name, child_module in module.named_modules():
|
195 |
+
if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
|
196 |
+
if train_method == 'xattn-strict':
|
197 |
+
if 'out' in child_name:
|
198 |
+
continue
|
199 |
+
if train_method == 'noxattn-hspace':
|
200 |
+
if 'mid_block' not in name:
|
201 |
+
continue
|
202 |
+
if train_method == 'noxattn-hspace-last':
|
203 |
+
if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
|
204 |
+
continue
|
205 |
+
lora_name = prefix + "." + name + "." + child_name
|
206 |
+
lora_name = lora_name.replace(".", "_")
|
207 |
+
# print(f"{lora_name}")
|
208 |
+
lora = self.module(
|
209 |
+
lora_name, child_module, multiplier, rank, self.alpha
|
210 |
+
)
|
211 |
+
# print(name, child_name)
|
212 |
+
# print(child_module.weight.shape)
|
213 |
+
loras.append(lora)
|
214 |
+
names.append(lora_name)
|
215 |
+
# print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
|
216 |
+
return loras
|
217 |
+
|
218 |
+
def prepare_optimizer_params(self):
|
219 |
+
all_params = []
|
220 |
+
|
221 |
+
if self.unet_loras: # ๅฎ่ณชใใใใใชใ
|
222 |
+
params = []
|
223 |
+
[params.extend(lora.parameters()) for lora in self.unet_loras]
|
224 |
+
param_data = {"params": params}
|
225 |
+
all_params.append(param_data)
|
226 |
+
|
227 |
+
return all_params
|
228 |
+
|
229 |
+
def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
|
230 |
+
state_dict = self.state_dict()
|
231 |
+
|
232 |
+
if dtype is not None:
|
233 |
+
for key in list(state_dict.keys()):
|
234 |
+
v = state_dict[key]
|
235 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
236 |
+
state_dict[key] = v
|
237 |
+
|
238 |
+
# for key in list(state_dict.keys()):
|
239 |
+
# if not key.startswith("lora"):
|
240 |
+
# # loraไปฅๅค้คๅค
|
241 |
+
# del state_dict[key]
|
242 |
+
|
243 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
244 |
+
save_file(state_dict, file, metadata)
|
245 |
+
else:
|
246 |
+
torch.save(state_dict, file)
|
247 |
+
def set_lora_slider(self, scale):
|
248 |
+
self.lora_scale = scale
|
249 |
+
|
250 |
+
def __enter__(self):
|
251 |
+
for lora in self.unet_loras:
|
252 |
+
lora.multiplier = 1.0 * self.lora_scale
|
253 |
+
|
254 |
+
def __exit__(self, exc_type, exc_value, tb):
|
255 |
+
for lora in self.unet_loras:
|
256 |
+
lora.multiplier = 0
|
trainscripts/imagesliders/model_util.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Union, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
5 |
+
from diffusers import (
|
6 |
+
UNet2DConditionModel,
|
7 |
+
SchedulerMixin,
|
8 |
+
StableDiffusionPipeline,
|
9 |
+
StableDiffusionXLPipeline,
|
10 |
+
AutoencoderKL,
|
11 |
+
)
|
12 |
+
from diffusers.schedulers import (
|
13 |
+
DDIMScheduler,
|
14 |
+
DDPMScheduler,
|
15 |
+
LMSDiscreteScheduler,
|
16 |
+
EulerAncestralDiscreteScheduler,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
|
21 |
+
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
|
22 |
+
|
23 |
+
AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
|
24 |
+
|
25 |
+
SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
|
26 |
+
|
27 |
+
DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
|
28 |
+
|
29 |
+
|
30 |
+
def load_diffusers_model(
|
31 |
+
pretrained_model_name_or_path: str,
|
32 |
+
v2: bool = False,
|
33 |
+
clip_skip: Optional[int] = None,
|
34 |
+
weight_dtype: torch.dtype = torch.float32,
|
35 |
+
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
|
36 |
+
# VAE ใฏใใใชใ
|
37 |
+
|
38 |
+
if v2:
|
39 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
40 |
+
TOKENIZER_V2_MODEL_NAME,
|
41 |
+
subfolder="tokenizer",
|
42 |
+
torch_dtype=weight_dtype,
|
43 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
44 |
+
)
|
45 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
46 |
+
pretrained_model_name_or_path,
|
47 |
+
subfolder="text_encoder",
|
48 |
+
# default is clip skip 2
|
49 |
+
num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
|
50 |
+
torch_dtype=weight_dtype,
|
51 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
55 |
+
TOKENIZER_V1_MODEL_NAME,
|
56 |
+
subfolder="tokenizer",
|
57 |
+
torch_dtype=weight_dtype,
|
58 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
59 |
+
)
|
60 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
61 |
+
pretrained_model_name_or_path,
|
62 |
+
subfolder="text_encoder",
|
63 |
+
num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
|
64 |
+
torch_dtype=weight_dtype,
|
65 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
66 |
+
)
|
67 |
+
|
68 |
+
unet = UNet2DConditionModel.from_pretrained(
|
69 |
+
pretrained_model_name_or_path,
|
70 |
+
subfolder="unet",
|
71 |
+
torch_dtype=weight_dtype,
|
72 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
73 |
+
)
|
74 |
+
|
75 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
76 |
+
|
77 |
+
return tokenizer, text_encoder, unet, vae
|
78 |
+
|
79 |
+
|
80 |
+
def load_checkpoint_model(
|
81 |
+
checkpoint_path: str,
|
82 |
+
v2: bool = False,
|
83 |
+
clip_skip: Optional[int] = None,
|
84 |
+
weight_dtype: torch.dtype = torch.float32,
|
85 |
+
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
|
86 |
+
pipe = StableDiffusionPipeline.from_ckpt(
|
87 |
+
checkpoint_path,
|
88 |
+
upcast_attention=True if v2 else False,
|
89 |
+
torch_dtype=weight_dtype,
|
90 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
91 |
+
)
|
92 |
+
|
93 |
+
unet = pipe.unet
|
94 |
+
tokenizer = pipe.tokenizer
|
95 |
+
text_encoder = pipe.text_encoder
|
96 |
+
vae = pipe.vae
|
97 |
+
if clip_skip is not None:
|
98 |
+
if v2:
|
99 |
+
text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
|
100 |
+
else:
|
101 |
+
text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
|
102 |
+
|
103 |
+
del pipe
|
104 |
+
|
105 |
+
return tokenizer, text_encoder, unet, vae
|
106 |
+
|
107 |
+
|
108 |
+
def load_models(
|
109 |
+
pretrained_model_name_or_path: str,
|
110 |
+
scheduler_name: AVAILABLE_SCHEDULERS,
|
111 |
+
v2: bool = False,
|
112 |
+
v_pred: bool = False,
|
113 |
+
weight_dtype: torch.dtype = torch.float32,
|
114 |
+
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
|
115 |
+
if pretrained_model_name_or_path.endswith(
|
116 |
+
".ckpt"
|
117 |
+
) or pretrained_model_name_or_path.endswith(".safetensors"):
|
118 |
+
tokenizer, text_encoder, unet, vae = load_checkpoint_model(
|
119 |
+
pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
|
120 |
+
)
|
121 |
+
else: # diffusers
|
122 |
+
tokenizer, text_encoder, unet, vae = load_diffusers_model(
|
123 |
+
pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
|
124 |
+
)
|
125 |
+
|
126 |
+
# VAE ใฏใใใชใ
|
127 |
+
|
128 |
+
scheduler = create_noise_scheduler(
|
129 |
+
scheduler_name,
|
130 |
+
prediction_type="v_prediction" if v_pred else "epsilon",
|
131 |
+
)
|
132 |
+
|
133 |
+
return tokenizer, text_encoder, unet, scheduler, vae
|
134 |
+
|
135 |
+
|
136 |
+
def load_diffusers_model_xl(
|
137 |
+
pretrained_model_name_or_path: str,
|
138 |
+
weight_dtype: torch.dtype = torch.float32,
|
139 |
+
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
|
140 |
+
# returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
|
141 |
+
|
142 |
+
tokenizers = [
|
143 |
+
CLIPTokenizer.from_pretrained(
|
144 |
+
pretrained_model_name_or_path,
|
145 |
+
subfolder="tokenizer",
|
146 |
+
torch_dtype=weight_dtype,
|
147 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
148 |
+
),
|
149 |
+
CLIPTokenizer.from_pretrained(
|
150 |
+
pretrained_model_name_or_path,
|
151 |
+
subfolder="tokenizer_2",
|
152 |
+
torch_dtype=weight_dtype,
|
153 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
154 |
+
pad_token_id=0, # same as open clip
|
155 |
+
),
|
156 |
+
]
|
157 |
+
|
158 |
+
text_encoders = [
|
159 |
+
CLIPTextModel.from_pretrained(
|
160 |
+
pretrained_model_name_or_path,
|
161 |
+
subfolder="text_encoder",
|
162 |
+
torch_dtype=weight_dtype,
|
163 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
164 |
+
),
|
165 |
+
CLIPTextModelWithProjection.from_pretrained(
|
166 |
+
pretrained_model_name_or_path,
|
167 |
+
subfolder="text_encoder_2",
|
168 |
+
torch_dtype=weight_dtype,
|
169 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
170 |
+
),
|
171 |
+
]
|
172 |
+
|
173 |
+
unet = UNet2DConditionModel.from_pretrained(
|
174 |
+
pretrained_model_name_or_path,
|
175 |
+
subfolder="unet",
|
176 |
+
torch_dtype=weight_dtype,
|
177 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
178 |
+
)
|
179 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
180 |
+
return tokenizers, text_encoders, unet, vae
|
181 |
+
|
182 |
+
|
183 |
+
def load_checkpoint_model_xl(
|
184 |
+
checkpoint_path: str,
|
185 |
+
weight_dtype: torch.dtype = torch.float32,
|
186 |
+
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
|
187 |
+
pipe = StableDiffusionXLPipeline.from_single_file(
|
188 |
+
checkpoint_path,
|
189 |
+
torch_dtype=weight_dtype,
|
190 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
191 |
+
)
|
192 |
+
|
193 |
+
unet = pipe.unet
|
194 |
+
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
195 |
+
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
196 |
+
if len(text_encoders) == 2:
|
197 |
+
text_encoders[1].pad_token_id = 0
|
198 |
+
|
199 |
+
del pipe
|
200 |
+
|
201 |
+
return tokenizers, text_encoders, unet
|
202 |
+
|
203 |
+
|
204 |
+
def load_models_xl(
|
205 |
+
pretrained_model_name_or_path: str,
|
206 |
+
scheduler_name: AVAILABLE_SCHEDULERS,
|
207 |
+
weight_dtype: torch.dtype = torch.float32,
|
208 |
+
) -> tuple[
|
209 |
+
list[CLIPTokenizer],
|
210 |
+
list[SDXL_TEXT_ENCODER_TYPE],
|
211 |
+
UNet2DConditionModel,
|
212 |
+
SchedulerMixin,
|
213 |
+
]:
|
214 |
+
if pretrained_model_name_or_path.endswith(
|
215 |
+
".ckpt"
|
216 |
+
) or pretrained_model_name_or_path.endswith(".safetensors"):
|
217 |
+
(
|
218 |
+
tokenizers,
|
219 |
+
text_encoders,
|
220 |
+
unet,
|
221 |
+
) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
|
222 |
+
else: # diffusers
|
223 |
+
(
|
224 |
+
tokenizers,
|
225 |
+
text_encoders,
|
226 |
+
unet,
|
227 |
+
vae
|
228 |
+
) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
|
229 |
+
|
230 |
+
scheduler = create_noise_scheduler(scheduler_name)
|
231 |
+
|
232 |
+
return tokenizers, text_encoders, unet, scheduler, vae
|
233 |
+
|
234 |
+
|
235 |
+
def create_noise_scheduler(
|
236 |
+
scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
|
237 |
+
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
|
238 |
+
) -> SchedulerMixin:
|
239 |
+
# ๆญฃ็ดใใฉใใใใใฎใใใใใชใใๅ
ใฎๅฎ่ฃ
ใ ใจDDIMใจDDPMใจLMSใ้ธในใใฎใ ใใฉใใฉใใใใใฎใใใใใฌใ
|
240 |
+
|
241 |
+
name = scheduler_name.lower().replace(" ", "_")
|
242 |
+
if name == "ddim":
|
243 |
+
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
|
244 |
+
scheduler = DDIMScheduler(
|
245 |
+
beta_start=0.00085,
|
246 |
+
beta_end=0.012,
|
247 |
+
beta_schedule="scaled_linear",
|
248 |
+
num_train_timesteps=1000,
|
249 |
+
clip_sample=False,
|
250 |
+
prediction_type=prediction_type, # ใใใงใใใฎ๏ผ
|
251 |
+
)
|
252 |
+
elif name == "ddpm":
|
253 |
+
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
|
254 |
+
scheduler = DDPMScheduler(
|
255 |
+
beta_start=0.00085,
|
256 |
+
beta_end=0.012,
|
257 |
+
beta_schedule="scaled_linear",
|
258 |
+
num_train_timesteps=1000,
|
259 |
+
clip_sample=False,
|
260 |
+
prediction_type=prediction_type,
|
261 |
+
)
|
262 |
+
elif name == "lms":
|
263 |
+
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
|
264 |
+
scheduler = LMSDiscreteScheduler(
|
265 |
+
beta_start=0.00085,
|
266 |
+
beta_end=0.012,
|
267 |
+
beta_schedule="scaled_linear",
|
268 |
+
num_train_timesteps=1000,
|
269 |
+
prediction_type=prediction_type,
|
270 |
+
)
|
271 |
+
elif name == "euler_a":
|
272 |
+
# https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
|
273 |
+
scheduler = EulerAncestralDiscreteScheduler(
|
274 |
+
beta_start=0.00085,
|
275 |
+
beta_end=0.012,
|
276 |
+
beta_schedule="scaled_linear",
|
277 |
+
num_train_timesteps=1000,
|
278 |
+
prediction_type=prediction_type,
|
279 |
+
)
|
280 |
+
else:
|
281 |
+
raise ValueError(f"Unknown scheduler name: {name}")
|
282 |
+
|
283 |
+
return scheduler
|
trainscripts/imagesliders/prompt_util.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Optional, Union, List
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
|
7 |
+
from pydantic import BaseModel, root_validator
|
8 |
+
import torch
|
9 |
+
import copy
|
10 |
+
|
11 |
+
ACTION_TYPES = Literal[
|
12 |
+
"erase",
|
13 |
+
"enhance",
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
# XL ใฏไบ็จฎ้กๅฟ
่ฆใชใฎใง
|
18 |
+
class PromptEmbedsXL:
|
19 |
+
text_embeds: torch.FloatTensor
|
20 |
+
pooled_embeds: torch.FloatTensor
|
21 |
+
|
22 |
+
def __init__(self, *args) -> None:
|
23 |
+
self.text_embeds = args[0]
|
24 |
+
self.pooled_embeds = args[1]
|
25 |
+
|
26 |
+
|
27 |
+
# SDv1.x, SDv2.x ใฏ FloatTensorใXL ใฏ PromptEmbedsXL
|
28 |
+
PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
|
29 |
+
|
30 |
+
|
31 |
+
class PromptEmbedsCache: # ไฝฟใใพใใใใใฎใง
|
32 |
+
prompts: dict[str, PROMPT_EMBEDDING] = {}
|
33 |
+
|
34 |
+
def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
|
35 |
+
self.prompts[__name] = __value
|
36 |
+
|
37 |
+
def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
|
38 |
+
if __name in self.prompts:
|
39 |
+
return self.prompts[__name]
|
40 |
+
else:
|
41 |
+
return None
|
42 |
+
|
43 |
+
|
44 |
+
class PromptSettings(BaseModel): # yaml ใฎใใค
|
45 |
+
target: str
|
46 |
+
positive: str = None # if None, target will be used
|
47 |
+
unconditional: str = "" # default is ""
|
48 |
+
neutral: str = None # if None, unconditional will be used
|
49 |
+
action: ACTION_TYPES = "erase" # default is "erase"
|
50 |
+
guidance_scale: float = 1.0 # default is 1.0
|
51 |
+
resolution: int = 512 # default is 512
|
52 |
+
dynamic_resolution: bool = False # default is False
|
53 |
+
batch_size: int = 1 # default is 1
|
54 |
+
dynamic_crops: bool = False # default is False. only used when model is XL
|
55 |
+
|
56 |
+
@root_validator(pre=True)
|
57 |
+
def fill_prompts(cls, values):
|
58 |
+
keys = values.keys()
|
59 |
+
if "target" not in keys:
|
60 |
+
raise ValueError("target must be specified")
|
61 |
+
if "positive" not in keys:
|
62 |
+
values["positive"] = values["target"]
|
63 |
+
if "unconditional" not in keys:
|
64 |
+
values["unconditional"] = ""
|
65 |
+
if "neutral" not in keys:
|
66 |
+
values["neutral"] = values["unconditional"]
|
67 |
+
|
68 |
+
return values
|
69 |
+
|
70 |
+
|
71 |
+
class PromptEmbedsPair:
|
72 |
+
target: PROMPT_EMBEDDING # not want to generate the concept
|
73 |
+
positive: PROMPT_EMBEDDING # generate the concept
|
74 |
+
unconditional: PROMPT_EMBEDDING # uncondition (default should be empty)
|
75 |
+
neutral: PROMPT_EMBEDDING # base condition (default should be empty)
|
76 |
+
|
77 |
+
guidance_scale: float
|
78 |
+
resolution: int
|
79 |
+
dynamic_resolution: bool
|
80 |
+
batch_size: int
|
81 |
+
dynamic_crops: bool
|
82 |
+
|
83 |
+
loss_fn: torch.nn.Module
|
84 |
+
action: ACTION_TYPES
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
loss_fn: torch.nn.Module,
|
89 |
+
target: PROMPT_EMBEDDING,
|
90 |
+
positive: PROMPT_EMBEDDING,
|
91 |
+
unconditional: PROMPT_EMBEDDING,
|
92 |
+
neutral: PROMPT_EMBEDDING,
|
93 |
+
settings: PromptSettings,
|
94 |
+
) -> None:
|
95 |
+
self.loss_fn = loss_fn
|
96 |
+
self.target = target
|
97 |
+
self.positive = positive
|
98 |
+
self.unconditional = unconditional
|
99 |
+
self.neutral = neutral
|
100 |
+
|
101 |
+
self.guidance_scale = settings.guidance_scale
|
102 |
+
self.resolution = settings.resolution
|
103 |
+
self.dynamic_resolution = settings.dynamic_resolution
|
104 |
+
self.batch_size = settings.batch_size
|
105 |
+
self.dynamic_crops = settings.dynamic_crops
|
106 |
+
self.action = settings.action
|
107 |
+
|
108 |
+
def _erase(
|
109 |
+
self,
|
110 |
+
target_latents: torch.FloatTensor, # "van gogh"
|
111 |
+
positive_latents: torch.FloatTensor, # "van gogh"
|
112 |
+
unconditional_latents: torch.FloatTensor, # ""
|
113 |
+
neutral_latents: torch.FloatTensor, # ""
|
114 |
+
) -> torch.FloatTensor:
|
115 |
+
"""Target latents are going not to have the positive concept."""
|
116 |
+
return self.loss_fn(
|
117 |
+
target_latents,
|
118 |
+
neutral_latents
|
119 |
+
- self.guidance_scale * (positive_latents - unconditional_latents)
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
def _enhance(
|
124 |
+
self,
|
125 |
+
target_latents: torch.FloatTensor, # "van gogh"
|
126 |
+
positive_latents: torch.FloatTensor, # "van gogh"
|
127 |
+
unconditional_latents: torch.FloatTensor, # ""
|
128 |
+
neutral_latents: torch.FloatTensor, # ""
|
129 |
+
):
|
130 |
+
"""Target latents are going to have the positive concept."""
|
131 |
+
return self.loss_fn(
|
132 |
+
target_latents,
|
133 |
+
neutral_latents
|
134 |
+
+ self.guidance_scale * (positive_latents - unconditional_latents)
|
135 |
+
)
|
136 |
+
|
137 |
+
def loss(
|
138 |
+
self,
|
139 |
+
**kwargs,
|
140 |
+
):
|
141 |
+
if self.action == "erase":
|
142 |
+
return self._erase(**kwargs)
|
143 |
+
|
144 |
+
elif self.action == "enhance":
|
145 |
+
return self._enhance(**kwargs)
|
146 |
+
|
147 |
+
else:
|
148 |
+
raise ValueError("action must be erase or enhance")
|
149 |
+
|
150 |
+
|
151 |
+
def load_prompts_from_yaml(path, attributes = []):
|
152 |
+
with open(path, "r") as f:
|
153 |
+
prompts = yaml.safe_load(f)
|
154 |
+
print(prompts)
|
155 |
+
if len(prompts) == 0:
|
156 |
+
raise ValueError("prompts file is empty")
|
157 |
+
if len(attributes)!=0:
|
158 |
+
newprompts = []
|
159 |
+
for i in range(len(prompts)):
|
160 |
+
for att in attributes:
|
161 |
+
copy_ = copy.deepcopy(prompts[i])
|
162 |
+
copy_['target'] = att + ' ' + copy_['target']
|
163 |
+
copy_['positive'] = att + ' ' + copy_['positive']
|
164 |
+
copy_['neutral'] = att + ' ' + copy_['neutral']
|
165 |
+
copy_['unconditional'] = att + ' ' + copy_['unconditional']
|
166 |
+
newprompts.append(copy_)
|
167 |
+
else:
|
168 |
+
newprompts = copy.deepcopy(prompts)
|
169 |
+
|
170 |
+
print(newprompts)
|
171 |
+
print(len(prompts), len(newprompts))
|
172 |
+
prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
|
173 |
+
|
174 |
+
return prompt_settings
|
trainscripts/imagesliders/train_lora-scale-xl.py
ADDED
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ref:
|
2 |
+
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
|
3 |
+
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
|
4 |
+
|
5 |
+
from typing import List, Optional
|
6 |
+
import argparse
|
7 |
+
import ast
|
8 |
+
from pathlib import Path
|
9 |
+
import gc, os
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
import train_util
|
19 |
+
import random
|
20 |
+
import model_util
|
21 |
+
import prompt_util
|
22 |
+
from prompt_util import (
|
23 |
+
PromptEmbedsCache,
|
24 |
+
PromptEmbedsPair,
|
25 |
+
PromptSettings,
|
26 |
+
PromptEmbedsXL,
|
27 |
+
)
|
28 |
+
import debug_util
|
29 |
+
import config_util
|
30 |
+
from config_util import RootConfig
|
31 |
+
|
32 |
+
import wandb
|
33 |
+
|
34 |
+
NUM_IMAGES_PER_PROMPT = 1
|
35 |
+
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
|
36 |
+
|
37 |
+
def flush():
|
38 |
+
torch.cuda.empty_cache()
|
39 |
+
gc.collect()
|
40 |
+
|
41 |
+
|
42 |
+
def train(
|
43 |
+
config: RootConfig,
|
44 |
+
prompts: list[PromptSettings],
|
45 |
+
device,
|
46 |
+
folder_main: str,
|
47 |
+
folders,
|
48 |
+
scales,
|
49 |
+
|
50 |
+
):
|
51 |
+
scales = np.array(scales)
|
52 |
+
folders = np.array(folders)
|
53 |
+
scales_unique = list(scales)
|
54 |
+
|
55 |
+
metadata = {
|
56 |
+
"prompts": ",".join([prompt.json() for prompt in prompts]),
|
57 |
+
"config": config.json(),
|
58 |
+
}
|
59 |
+
save_path = Path(config.save.path)
|
60 |
+
|
61 |
+
modules = DEFAULT_TARGET_REPLACE
|
62 |
+
if config.network.type == "c3lier":
|
63 |
+
modules += UNET_TARGET_REPLACE_MODULE_CONV
|
64 |
+
|
65 |
+
if config.logging.verbose:
|
66 |
+
print(metadata)
|
67 |
+
|
68 |
+
if config.logging.use_wandb:
|
69 |
+
wandb.init(project=f"LECO_{config.save.name}", config=metadata)
|
70 |
+
|
71 |
+
weight_dtype = config_util.parse_precision(config.train.precision)
|
72 |
+
save_weight_dtype = config_util.parse_precision(config.train.precision)
|
73 |
+
|
74 |
+
(
|
75 |
+
tokenizers,
|
76 |
+
text_encoders,
|
77 |
+
unet,
|
78 |
+
noise_scheduler,
|
79 |
+
vae
|
80 |
+
) = model_util.load_models_xl(
|
81 |
+
config.pretrained_model.name_or_path,
|
82 |
+
scheduler_name=config.train.noise_scheduler,
|
83 |
+
)
|
84 |
+
|
85 |
+
for text_encoder in text_encoders:
|
86 |
+
text_encoder.to(device, dtype=weight_dtype)
|
87 |
+
text_encoder.requires_grad_(False)
|
88 |
+
text_encoder.eval()
|
89 |
+
|
90 |
+
unet.to(device, dtype=weight_dtype)
|
91 |
+
if config.other.use_xformers:
|
92 |
+
unet.enable_xformers_memory_efficient_attention()
|
93 |
+
unet.requires_grad_(False)
|
94 |
+
unet.eval()
|
95 |
+
|
96 |
+
vae.to(device)
|
97 |
+
vae.requires_grad_(False)
|
98 |
+
vae.eval()
|
99 |
+
|
100 |
+
network = LoRANetwork(
|
101 |
+
unet,
|
102 |
+
rank=config.network.rank,
|
103 |
+
multiplier=1.0,
|
104 |
+
alpha=config.network.alpha,
|
105 |
+
train_method=config.network.training_method,
|
106 |
+
).to(device, dtype=weight_dtype)
|
107 |
+
|
108 |
+
optimizer_module = train_util.get_optimizer(config.train.optimizer)
|
109 |
+
#optimizer_args
|
110 |
+
optimizer_kwargs = {}
|
111 |
+
if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
|
112 |
+
for arg in config.train.optimizer_args.split(" "):
|
113 |
+
key, value = arg.split("=")
|
114 |
+
value = ast.literal_eval(value)
|
115 |
+
optimizer_kwargs[key] = value
|
116 |
+
|
117 |
+
optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
|
118 |
+
lr_scheduler = train_util.get_lr_scheduler(
|
119 |
+
config.train.lr_scheduler,
|
120 |
+
optimizer,
|
121 |
+
max_iterations=config.train.iterations,
|
122 |
+
lr_min=config.train.lr / 100,
|
123 |
+
)
|
124 |
+
criteria = torch.nn.MSELoss()
|
125 |
+
|
126 |
+
print("Prompts")
|
127 |
+
for settings in prompts:
|
128 |
+
print(settings)
|
129 |
+
|
130 |
+
# debug
|
131 |
+
debug_util.check_requires_grad(network)
|
132 |
+
debug_util.check_training_mode(network)
|
133 |
+
|
134 |
+
cache = PromptEmbedsCache()
|
135 |
+
prompt_pairs: list[PromptEmbedsPair] = []
|
136 |
+
|
137 |
+
with torch.no_grad():
|
138 |
+
for settings in prompts:
|
139 |
+
print(settings)
|
140 |
+
for prompt in [
|
141 |
+
settings.target,
|
142 |
+
settings.positive,
|
143 |
+
settings.neutral,
|
144 |
+
settings.unconditional,
|
145 |
+
]:
|
146 |
+
if cache[prompt] == None:
|
147 |
+
tex_embs, pool_embs = train_util.encode_prompts_xl(
|
148 |
+
tokenizers,
|
149 |
+
text_encoders,
|
150 |
+
[prompt],
|
151 |
+
num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
|
152 |
+
)
|
153 |
+
cache[prompt] = PromptEmbedsXL(
|
154 |
+
tex_embs,
|
155 |
+
pool_embs
|
156 |
+
)
|
157 |
+
|
158 |
+
prompt_pairs.append(
|
159 |
+
PromptEmbedsPair(
|
160 |
+
criteria,
|
161 |
+
cache[settings.target],
|
162 |
+
cache[settings.positive],
|
163 |
+
cache[settings.unconditional],
|
164 |
+
cache[settings.neutral],
|
165 |
+
settings,
|
166 |
+
)
|
167 |
+
)
|
168 |
+
|
169 |
+
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
170 |
+
del tokenizer, text_encoder
|
171 |
+
|
172 |
+
flush()
|
173 |
+
|
174 |
+
pbar = tqdm(range(config.train.iterations))
|
175 |
+
|
176 |
+
loss = None
|
177 |
+
|
178 |
+
for i in pbar:
|
179 |
+
with torch.no_grad():
|
180 |
+
noise_scheduler.set_timesteps(
|
181 |
+
config.train.max_denoising_steps, device=device
|
182 |
+
)
|
183 |
+
|
184 |
+
optimizer.zero_grad()
|
185 |
+
|
186 |
+
prompt_pair: PromptEmbedsPair = prompt_pairs[
|
187 |
+
torch.randint(0, len(prompt_pairs), (1,)).item()
|
188 |
+
]
|
189 |
+
|
190 |
+
# 1 ~ 49 ใใใฉใณใใ
|
191 |
+
timesteps_to = torch.randint(
|
192 |
+
1, config.train.max_denoising_steps, (1,)
|
193 |
+
).item()
|
194 |
+
|
195 |
+
height, width = prompt_pair.resolution, prompt_pair.resolution
|
196 |
+
if prompt_pair.dynamic_resolution:
|
197 |
+
height, width = train_util.get_random_resolution_in_bucket(
|
198 |
+
prompt_pair.resolution
|
199 |
+
)
|
200 |
+
|
201 |
+
if config.logging.verbose:
|
202 |
+
print("guidance_scale:", prompt_pair.guidance_scale)
|
203 |
+
print("resolution:", prompt_pair.resolution)
|
204 |
+
print("dynamic_resolution:", prompt_pair.dynamic_resolution)
|
205 |
+
if prompt_pair.dynamic_resolution:
|
206 |
+
print("bucketed resolution:", (height, width))
|
207 |
+
print("batch_size:", prompt_pair.batch_size)
|
208 |
+
print("dynamic_crops:", prompt_pair.dynamic_crops)
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
scale_to_look = abs(random.choice(list(scales_unique)))
|
213 |
+
folder1 = folders[scales==-scale_to_look][0]
|
214 |
+
folder2 = folders[scales==scale_to_look][0]
|
215 |
+
|
216 |
+
ims = os.listdir(f'{folder_main}/{folder1}/')
|
217 |
+
ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_]
|
218 |
+
random_sampler = random.randint(0, len(ims)-1)
|
219 |
+
|
220 |
+
img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((512,512))
|
221 |
+
img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((512,512))
|
222 |
+
|
223 |
+
seed = random.randint(0,2*15)
|
224 |
+
|
225 |
+
generator = torch.manual_seed(seed)
|
226 |
+
denoised_latents_low, low_noise = train_util.get_noisy_image(
|
227 |
+
img1,
|
228 |
+
vae,
|
229 |
+
generator,
|
230 |
+
unet,
|
231 |
+
noise_scheduler,
|
232 |
+
start_timesteps=0,
|
233 |
+
total_timesteps=timesteps_to)
|
234 |
+
denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
|
235 |
+
low_noise = low_noise.to(device, dtype=weight_dtype)
|
236 |
+
|
237 |
+
generator = torch.manual_seed(seed)
|
238 |
+
denoised_latents_high, high_noise = train_util.get_noisy_image(
|
239 |
+
img2,
|
240 |
+
vae,
|
241 |
+
generator,
|
242 |
+
unet,
|
243 |
+
noise_scheduler,
|
244 |
+
start_timesteps=0,
|
245 |
+
total_timesteps=timesteps_to)
|
246 |
+
denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
|
247 |
+
high_noise = high_noise.to(device, dtype=weight_dtype)
|
248 |
+
noise_scheduler.set_timesteps(1000)
|
249 |
+
|
250 |
+
add_time_ids = train_util.get_add_time_ids(
|
251 |
+
height,
|
252 |
+
width,
|
253 |
+
dynamic_crops=prompt_pair.dynamic_crops,
|
254 |
+
dtype=weight_dtype,
|
255 |
+
).to(device, dtype=weight_dtype)
|
256 |
+
|
257 |
+
|
258 |
+
current_timestep = noise_scheduler.timesteps[
|
259 |
+
int(timesteps_to * 1000 / config.train.max_denoising_steps)
|
260 |
+
]
|
261 |
+
try:
|
262 |
+
# with network: ใฎๅคใงใฏ็ฉบใฎLoRAใฎใฟใๆๅนใซใชใ
|
263 |
+
high_latents = train_util.predict_noise_xl(
|
264 |
+
unet,
|
265 |
+
noise_scheduler,
|
266 |
+
current_timestep,
|
267 |
+
denoised_latents_high,
|
268 |
+
text_embeddings=train_util.concat_embeddings(
|
269 |
+
prompt_pair.unconditional.text_embeds,
|
270 |
+
prompt_pair.positive.text_embeds,
|
271 |
+
prompt_pair.batch_size,
|
272 |
+
),
|
273 |
+
add_text_embeddings=train_util.concat_embeddings(
|
274 |
+
prompt_pair.unconditional.pooled_embeds,
|
275 |
+
prompt_pair.positive.pooled_embeds,
|
276 |
+
prompt_pair.batch_size,
|
277 |
+
),
|
278 |
+
add_time_ids=train_util.concat_embeddings(
|
279 |
+
add_time_ids, add_time_ids, prompt_pair.batch_size
|
280 |
+
),
|
281 |
+
guidance_scale=1,
|
282 |
+
).to(device, dtype=torch.float32)
|
283 |
+
except:
|
284 |
+
flush()
|
285 |
+
print(f'Error Occured!: {np.array(img1).shape} {np.array(img2).shape}')
|
286 |
+
continue
|
287 |
+
# with network: ใฎๅคใงใฏ็ฉบใฎLoRAใฎใฟใๆๅนใซใชใ
|
288 |
+
|
289 |
+
low_latents = train_util.predict_noise_xl(
|
290 |
+
unet,
|
291 |
+
noise_scheduler,
|
292 |
+
current_timestep,
|
293 |
+
denoised_latents_low,
|
294 |
+
text_embeddings=train_util.concat_embeddings(
|
295 |
+
prompt_pair.unconditional.text_embeds,
|
296 |
+
prompt_pair.neutral.text_embeds,
|
297 |
+
prompt_pair.batch_size,
|
298 |
+
),
|
299 |
+
add_text_embeddings=train_util.concat_embeddings(
|
300 |
+
prompt_pair.unconditional.pooled_embeds,
|
301 |
+
prompt_pair.neutral.pooled_embeds,
|
302 |
+
prompt_pair.batch_size,
|
303 |
+
),
|
304 |
+
add_time_ids=train_util.concat_embeddings(
|
305 |
+
add_time_ids, add_time_ids, prompt_pair.batch_size
|
306 |
+
),
|
307 |
+
guidance_scale=1,
|
308 |
+
).to(device, dtype=torch.float32)
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
if config.logging.verbose:
|
313 |
+
print("positive_latents:", positive_latents[0, 0, :5, :5])
|
314 |
+
print("neutral_latents:", neutral_latents[0, 0, :5, :5])
|
315 |
+
print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
|
316 |
+
|
317 |
+
network.set_lora_slider(scale=scale_to_look)
|
318 |
+
with network:
|
319 |
+
target_latents_high = train_util.predict_noise_xl(
|
320 |
+
unet,
|
321 |
+
noise_scheduler,
|
322 |
+
current_timestep,
|
323 |
+
denoised_latents_high,
|
324 |
+
text_embeddings=train_util.concat_embeddings(
|
325 |
+
prompt_pair.unconditional.text_embeds,
|
326 |
+
prompt_pair.positive.text_embeds,
|
327 |
+
prompt_pair.batch_size,
|
328 |
+
),
|
329 |
+
add_text_embeddings=train_util.concat_embeddings(
|
330 |
+
prompt_pair.unconditional.pooled_embeds,
|
331 |
+
prompt_pair.positive.pooled_embeds,
|
332 |
+
prompt_pair.batch_size,
|
333 |
+
),
|
334 |
+
add_time_ids=train_util.concat_embeddings(
|
335 |
+
add_time_ids, add_time_ids, prompt_pair.batch_size
|
336 |
+
),
|
337 |
+
guidance_scale=1,
|
338 |
+
).to(device, dtype=torch.float32)
|
339 |
+
|
340 |
+
high_latents.requires_grad = False
|
341 |
+
low_latents.requires_grad = False
|
342 |
+
|
343 |
+
loss_high = criteria(target_latents_high, high_noise.to(torch.float32))
|
344 |
+
pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
|
345 |
+
loss_high.backward()
|
346 |
+
|
347 |
+
# opposite
|
348 |
+
network.set_lora_slider(scale=-scale_to_look)
|
349 |
+
with network:
|
350 |
+
target_latents_low = train_util.predict_noise_xl(
|
351 |
+
unet,
|
352 |
+
noise_scheduler,
|
353 |
+
current_timestep,
|
354 |
+
denoised_latents_low,
|
355 |
+
text_embeddings=train_util.concat_embeddings(
|
356 |
+
prompt_pair.unconditional.text_embeds,
|
357 |
+
prompt_pair.neutral.text_embeds,
|
358 |
+
prompt_pair.batch_size,
|
359 |
+
),
|
360 |
+
add_text_embeddings=train_util.concat_embeddings(
|
361 |
+
prompt_pair.unconditional.pooled_embeds,
|
362 |
+
prompt_pair.neutral.pooled_embeds,
|
363 |
+
prompt_pair.batch_size,
|
364 |
+
),
|
365 |
+
add_time_ids=train_util.concat_embeddings(
|
366 |
+
add_time_ids, add_time_ids, prompt_pair.batch_size
|
367 |
+
),
|
368 |
+
guidance_scale=1,
|
369 |
+
).to(device, dtype=torch.float32)
|
370 |
+
|
371 |
+
|
372 |
+
high_latents.requires_grad = False
|
373 |
+
low_latents.requires_grad = False
|
374 |
+
|
375 |
+
loss_low = criteria(target_latents_low, low_noise.to(torch.float32))
|
376 |
+
pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
|
377 |
+
loss_low.backward()
|
378 |
+
|
379 |
+
|
380 |
+
optimizer.step()
|
381 |
+
lr_scheduler.step()
|
382 |
+
|
383 |
+
del (
|
384 |
+
high_latents,
|
385 |
+
low_latents,
|
386 |
+
target_latents_low,
|
387 |
+
target_latents_high,
|
388 |
+
)
|
389 |
+
flush()
|
390 |
+
|
391 |
+
if (
|
392 |
+
i % config.save.per_steps == 0
|
393 |
+
and i != 0
|
394 |
+
and i != config.train.iterations - 1
|
395 |
+
):
|
396 |
+
print("Saving...")
|
397 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
398 |
+
network.save_weights(
|
399 |
+
save_path / f"{config.save.name}_{i}steps.pt",
|
400 |
+
dtype=save_weight_dtype,
|
401 |
+
)
|
402 |
+
|
403 |
+
print("Saving...")
|
404 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
405 |
+
network.save_weights(
|
406 |
+
save_path / f"{config.save.name}_last.pt",
|
407 |
+
dtype=save_weight_dtype,
|
408 |
+
)
|
409 |
+
|
410 |
+
del (
|
411 |
+
unet,
|
412 |
+
noise_scheduler,
|
413 |
+
loss,
|
414 |
+
optimizer,
|
415 |
+
network,
|
416 |
+
)
|
417 |
+
|
418 |
+
flush()
|
419 |
+
|
420 |
+
print("Done.")
|
421 |
+
|
422 |
+
|
423 |
+
def main(args):
|
424 |
+
config_file = args.config_file
|
425 |
+
|
426 |
+
config = config_util.load_config_from_yaml(config_file)
|
427 |
+
if args.name is not None:
|
428 |
+
config.save.name = args.name
|
429 |
+
attributes = []
|
430 |
+
if args.attributes is not None:
|
431 |
+
attributes = args.attributes.split(',')
|
432 |
+
attributes = [a.strip() for a in attributes]
|
433 |
+
|
434 |
+
config.network.alpha = args.alpha
|
435 |
+
config.network.rank = args.rank
|
436 |
+
config.save.name += f'_alpha{args.alpha}'
|
437 |
+
config.save.name += f'_rank{config.network.rank }'
|
438 |
+
config.save.name += f'_{config.network.training_method}'
|
439 |
+
config.save.path += f'/{config.save.name}'
|
440 |
+
|
441 |
+
prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
|
442 |
+
|
443 |
+
device = torch.device(f"cuda:{args.device}")
|
444 |
+
|
445 |
+
folders = args.folders.split(',')
|
446 |
+
folders = [f.strip() for f in folders]
|
447 |
+
scales = args.scales.split(',')
|
448 |
+
scales = [f.strip() for f in scales]
|
449 |
+
scales = [int(s) for s in scales]
|
450 |
+
|
451 |
+
print(folders, scales)
|
452 |
+
if len(scales) != len(folders):
|
453 |
+
raise Exception('the number of folders need to match the number of scales')
|
454 |
+
|
455 |
+
if args.stylecheck is not None:
|
456 |
+
check = args.stylecheck.split('-')
|
457 |
+
|
458 |
+
for i in range(int(check[0]), int(check[1])):
|
459 |
+
folder_main = args.folder_main+ f'{i}'
|
460 |
+
config.save.name = f'{os.path.basename(folder_main)}'
|
461 |
+
config.save.name += f'_alpha{args.alpha}'
|
462 |
+
config.save.name += f'_rank{config.network.rank }'
|
463 |
+
config.save.path = f'models/{config.save.name}'
|
464 |
+
train(config=config, prompts=prompts, device=device, folder_main = folder_main, folders = folders, scales = scales)
|
465 |
+
else:
|
466 |
+
train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales)
|
467 |
+
|
468 |
+
|
469 |
+
if __name__ == "__main__":
|
470 |
+
parser = argparse.ArgumentParser()
|
471 |
+
parser.add_argument(
|
472 |
+
"--config_file",
|
473 |
+
required=True,
|
474 |
+
help="Config file for training.",
|
475 |
+
)
|
476 |
+
# config_file 'data/config.yaml'
|
477 |
+
parser.add_argument(
|
478 |
+
"--alpha",
|
479 |
+
type=float,
|
480 |
+
required=True,
|
481 |
+
help="LoRA weight.",
|
482 |
+
)
|
483 |
+
# --alpha 1.0
|
484 |
+
parser.add_argument(
|
485 |
+
"--rank",
|
486 |
+
type=int,
|
487 |
+
required=False,
|
488 |
+
help="Rank of LoRA.",
|
489 |
+
default=4,
|
490 |
+
)
|
491 |
+
# --rank 4
|
492 |
+
parser.add_argument(
|
493 |
+
"--device",
|
494 |
+
type=int,
|
495 |
+
required=False,
|
496 |
+
default=0,
|
497 |
+
help="Device to train on.",
|
498 |
+
)
|
499 |
+
# --device 0
|
500 |
+
parser.add_argument(
|
501 |
+
"--name",
|
502 |
+
type=str,
|
503 |
+
required=False,
|
504 |
+
default=None,
|
505 |
+
help="Device to train on.",
|
506 |
+
)
|
507 |
+
# --name 'eyesize_slider'
|
508 |
+
parser.add_argument(
|
509 |
+
"--attributes",
|
510 |
+
type=str,
|
511 |
+
required=False,
|
512 |
+
default=None,
|
513 |
+
help="attritbutes to disentangle (comma seperated string)",
|
514 |
+
)
|
515 |
+
parser.add_argument(
|
516 |
+
"--folder_main",
|
517 |
+
type=str,
|
518 |
+
required=True,
|
519 |
+
help="The folder to check",
|
520 |
+
)
|
521 |
+
|
522 |
+
parser.add_argument(
|
523 |
+
"--stylecheck",
|
524 |
+
type=str,
|
525 |
+
required=False,
|
526 |
+
default = None,
|
527 |
+
help="The folder to check",
|
528 |
+
)
|
529 |
+
|
530 |
+
parser.add_argument(
|
531 |
+
"--folders",
|
532 |
+
type=str,
|
533 |
+
required=False,
|
534 |
+
default = 'verylow, low, high, veryhigh',
|
535 |
+
help="folders with different attribute-scaled images",
|
536 |
+
)
|
537 |
+
parser.add_argument(
|
538 |
+
"--scales",
|
539 |
+
type=str,
|
540 |
+
required=False,
|
541 |
+
default = '-2, -1, 1, 2',
|
542 |
+
help="scales for different attribute-scaled images",
|
543 |
+
)
|
544 |
+
|
545 |
+
|
546 |
+
args = parser.parse_args()
|
547 |
+
|
548 |
+
main(args)
|
trainscripts/imagesliders/train_lora-scale.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ref:
|
2 |
+
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
|
3 |
+
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
|
4 |
+
|
5 |
+
from typing import List, Optional
|
6 |
+
import argparse
|
7 |
+
import ast
|
8 |
+
from pathlib import Path
|
9 |
+
import gc
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from tqdm import tqdm
|
13 |
+
import os, glob
|
14 |
+
|
15 |
+
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
|
16 |
+
import train_util
|
17 |
+
import model_util
|
18 |
+
import prompt_util
|
19 |
+
from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings
|
20 |
+
import debug_util
|
21 |
+
import config_util
|
22 |
+
from config_util import RootConfig
|
23 |
+
import random
|
24 |
+
import numpy as np
|
25 |
+
import wandb
|
26 |
+
from PIL import Image
|
27 |
+
|
28 |
+
def flush():
|
29 |
+
torch.cuda.empty_cache()
|
30 |
+
gc.collect()
|
31 |
+
def prev_step(model_output, timestep, scheduler, sample):
|
32 |
+
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
|
33 |
+
alpha_prod_t =scheduler.alphas_cumprod[timestep]
|
34 |
+
alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
|
35 |
+
beta_prod_t = 1 - alpha_prod_t
|
36 |
+
pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
37 |
+
pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
|
38 |
+
prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
|
39 |
+
return prev_sample
|
40 |
+
|
41 |
+
def train(
|
42 |
+
config: RootConfig,
|
43 |
+
prompts: list[PromptSettings],
|
44 |
+
device: int,
|
45 |
+
folder_main: str,
|
46 |
+
folders,
|
47 |
+
scales,
|
48 |
+
):
|
49 |
+
scales = np.array(scales)
|
50 |
+
folders = np.array(folders)
|
51 |
+
scales_unique = list(scales)
|
52 |
+
|
53 |
+
metadata = {
|
54 |
+
"prompts": ",".join([prompt.json() for prompt in prompts]),
|
55 |
+
"config": config.json(),
|
56 |
+
}
|
57 |
+
save_path = Path(config.save.path)
|
58 |
+
|
59 |
+
modules = DEFAULT_TARGET_REPLACE
|
60 |
+
if config.network.type == "c3lier":
|
61 |
+
modules += UNET_TARGET_REPLACE_MODULE_CONV
|
62 |
+
|
63 |
+
if config.logging.verbose:
|
64 |
+
print(metadata)
|
65 |
+
|
66 |
+
if config.logging.use_wandb:
|
67 |
+
wandb.init(project=f"LECO_{config.save.name}", config=metadata)
|
68 |
+
|
69 |
+
weight_dtype = config_util.parse_precision(config.train.precision)
|
70 |
+
save_weight_dtype = config_util.parse_precision(config.train.precision)
|
71 |
+
|
72 |
+
tokenizer, text_encoder, unet, noise_scheduler, vae = model_util.load_models(
|
73 |
+
config.pretrained_model.name_or_path,
|
74 |
+
scheduler_name=config.train.noise_scheduler,
|
75 |
+
v2=config.pretrained_model.v2,
|
76 |
+
v_pred=config.pretrained_model.v_pred,
|
77 |
+
)
|
78 |
+
|
79 |
+
text_encoder.to(device, dtype=weight_dtype)
|
80 |
+
text_encoder.eval()
|
81 |
+
|
82 |
+
unet.to(device, dtype=weight_dtype)
|
83 |
+
unet.enable_xformers_memory_efficient_attention()
|
84 |
+
unet.requires_grad_(False)
|
85 |
+
unet.eval()
|
86 |
+
|
87 |
+
vae.to(device)
|
88 |
+
vae.requires_grad_(False)
|
89 |
+
vae.eval()
|
90 |
+
|
91 |
+
network = LoRANetwork(
|
92 |
+
unet,
|
93 |
+
rank=config.network.rank,
|
94 |
+
multiplier=1.0,
|
95 |
+
alpha=config.network.alpha,
|
96 |
+
train_method=config.network.training_method,
|
97 |
+
).to(device, dtype=weight_dtype)
|
98 |
+
|
99 |
+
optimizer_module = train_util.get_optimizer(config.train.optimizer)
|
100 |
+
#optimizer_args
|
101 |
+
optimizer_kwargs = {}
|
102 |
+
if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
|
103 |
+
for arg in config.train.optimizer_args.split(" "):
|
104 |
+
key, value = arg.split("=")
|
105 |
+
value = ast.literal_eval(value)
|
106 |
+
optimizer_kwargs[key] = value
|
107 |
+
|
108 |
+
optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
|
109 |
+
lr_scheduler = train_util.get_lr_scheduler(
|
110 |
+
config.train.lr_scheduler,
|
111 |
+
optimizer,
|
112 |
+
max_iterations=config.train.iterations,
|
113 |
+
lr_min=config.train.lr / 100,
|
114 |
+
)
|
115 |
+
criteria = torch.nn.MSELoss()
|
116 |
+
|
117 |
+
print("Prompts")
|
118 |
+
for settings in prompts:
|
119 |
+
print(settings)
|
120 |
+
|
121 |
+
# debug
|
122 |
+
debug_util.check_requires_grad(network)
|
123 |
+
debug_util.check_training_mode(network)
|
124 |
+
|
125 |
+
cache = PromptEmbedsCache()
|
126 |
+
prompt_pairs: list[PromptEmbedsPair] = []
|
127 |
+
|
128 |
+
with torch.no_grad():
|
129 |
+
for settings in prompts:
|
130 |
+
print(settings)
|
131 |
+
for prompt in [
|
132 |
+
settings.target,
|
133 |
+
settings.positive,
|
134 |
+
settings.neutral,
|
135 |
+
settings.unconditional,
|
136 |
+
]:
|
137 |
+
print(prompt)
|
138 |
+
if isinstance(prompt, list):
|
139 |
+
if prompt == settings.positive:
|
140 |
+
key_setting = 'positive'
|
141 |
+
else:
|
142 |
+
key_setting = 'attributes'
|
143 |
+
if len(prompt) == 0:
|
144 |
+
cache[key_setting] = []
|
145 |
+
else:
|
146 |
+
if cache[key_setting] is None:
|
147 |
+
cache[key_setting] = train_util.encode_prompts(
|
148 |
+
tokenizer, text_encoder, prompt
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
if cache[prompt] == None:
|
152 |
+
cache[prompt] = train_util.encode_prompts(
|
153 |
+
tokenizer, text_encoder, [prompt]
|
154 |
+
)
|
155 |
+
|
156 |
+
prompt_pairs.append(
|
157 |
+
PromptEmbedsPair(
|
158 |
+
criteria,
|
159 |
+
cache[settings.target],
|
160 |
+
cache[settings.positive],
|
161 |
+
cache[settings.unconditional],
|
162 |
+
cache[settings.neutral],
|
163 |
+
settings,
|
164 |
+
)
|
165 |
+
)
|
166 |
+
|
167 |
+
del tokenizer
|
168 |
+
del text_encoder
|
169 |
+
|
170 |
+
flush()
|
171 |
+
|
172 |
+
pbar = tqdm(range(config.train.iterations))
|
173 |
+
for i in pbar:
|
174 |
+
with torch.no_grad():
|
175 |
+
noise_scheduler.set_timesteps(
|
176 |
+
config.train.max_denoising_steps, device=device
|
177 |
+
)
|
178 |
+
|
179 |
+
optimizer.zero_grad()
|
180 |
+
|
181 |
+
prompt_pair: PromptEmbedsPair = prompt_pairs[
|
182 |
+
torch.randint(0, len(prompt_pairs), (1,)).item()
|
183 |
+
]
|
184 |
+
|
185 |
+
# 1 ~ 49 ใใใฉใณใใ
|
186 |
+
timesteps_to = torch.randint(
|
187 |
+
1, config.train.max_denoising_steps-1, (1,)
|
188 |
+
# 1, 25, (1,)
|
189 |
+
).item()
|
190 |
+
|
191 |
+
height, width = (
|
192 |
+
prompt_pair.resolution,
|
193 |
+
prompt_pair.resolution,
|
194 |
+
)
|
195 |
+
if prompt_pair.dynamic_resolution:
|
196 |
+
height, width = train_util.get_random_resolution_in_bucket(
|
197 |
+
prompt_pair.resolution
|
198 |
+
)
|
199 |
+
|
200 |
+
if config.logging.verbose:
|
201 |
+
print("guidance_scale:", prompt_pair.guidance_scale)
|
202 |
+
print("resolution:", prompt_pair.resolution)
|
203 |
+
print("dynamic_resolution:", prompt_pair.dynamic_resolution)
|
204 |
+
if prompt_pair.dynamic_resolution:
|
205 |
+
print("bucketed resolution:", (height, width))
|
206 |
+
print("batch_size:", prompt_pair.batch_size)
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
scale_to_look = abs(random.choice(list(scales_unique)))
|
212 |
+
folder1 = folders[scales==-scale_to_look][0]
|
213 |
+
folder2 = folders[scales==scale_to_look][0]
|
214 |
+
|
215 |
+
ims = os.listdir(f'{folder_main}/{folder1}/')
|
216 |
+
ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_]
|
217 |
+
random_sampler = random.randint(0, len(ims)-1)
|
218 |
+
|
219 |
+
img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((256,256))
|
220 |
+
img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((256,256))
|
221 |
+
|
222 |
+
seed = random.randint(0,2*15)
|
223 |
+
|
224 |
+
generator = torch.manual_seed(seed)
|
225 |
+
denoised_latents_low, low_noise = train_util.get_noisy_image(
|
226 |
+
img1,
|
227 |
+
vae,
|
228 |
+
generator,
|
229 |
+
unet,
|
230 |
+
noise_scheduler,
|
231 |
+
start_timesteps=0,
|
232 |
+
total_timesteps=timesteps_to)
|
233 |
+
denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
|
234 |
+
low_noise = low_noise.to(device, dtype=weight_dtype)
|
235 |
+
|
236 |
+
generator = torch.manual_seed(seed)
|
237 |
+
denoised_latents_high, high_noise = train_util.get_noisy_image(
|
238 |
+
img2,
|
239 |
+
vae,
|
240 |
+
generator,
|
241 |
+
unet,
|
242 |
+
noise_scheduler,
|
243 |
+
start_timesteps=0,
|
244 |
+
total_timesteps=timesteps_to)
|
245 |
+
denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
|
246 |
+
high_noise = high_noise.to(device, dtype=weight_dtype)
|
247 |
+
noise_scheduler.set_timesteps(1000)
|
248 |
+
|
249 |
+
current_timestep = noise_scheduler.timesteps[
|
250 |
+
int(timesteps_to * 1000 / config.train.max_denoising_steps)
|
251 |
+
]
|
252 |
+
|
253 |
+
# with network: ใฎๅคใงใฏ็ฉบใฎLoRAใฎใฟใๆๅนใซใชใ
|
254 |
+
high_latents = train_util.predict_noise(
|
255 |
+
unet,
|
256 |
+
noise_scheduler,
|
257 |
+
current_timestep,
|
258 |
+
denoised_latents_high,
|
259 |
+
train_util.concat_embeddings(
|
260 |
+
prompt_pair.unconditional,
|
261 |
+
prompt_pair.positive,
|
262 |
+
prompt_pair.batch_size,
|
263 |
+
),
|
264 |
+
guidance_scale=1,
|
265 |
+
).to("cpu", dtype=torch.float32)
|
266 |
+
# with network: ใฎๅคใงใฏ็ฉบใฎLoRAใฎใฟใๆๅนใซใชใ
|
267 |
+
low_latents = train_util.predict_noise(
|
268 |
+
unet,
|
269 |
+
noise_scheduler,
|
270 |
+
current_timestep,
|
271 |
+
denoised_latents_low,
|
272 |
+
train_util.concat_embeddings(
|
273 |
+
prompt_pair.unconditional,
|
274 |
+
prompt_pair.unconditional,
|
275 |
+
prompt_pair.batch_size,
|
276 |
+
),
|
277 |
+
guidance_scale=1,
|
278 |
+
).to("cpu", dtype=torch.float32)
|
279 |
+
if config.logging.verbose:
|
280 |
+
print("positive_latents:", positive_latents[0, 0, :5, :5])
|
281 |
+
print("neutral_latents:", neutral_latents[0, 0, :5, :5])
|
282 |
+
print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
|
283 |
+
|
284 |
+
network.set_lora_slider(scale=scale_to_look)
|
285 |
+
with network:
|
286 |
+
target_latents_high = train_util.predict_noise(
|
287 |
+
unet,
|
288 |
+
noise_scheduler,
|
289 |
+
current_timestep,
|
290 |
+
denoised_latents_high,
|
291 |
+
train_util.concat_embeddings(
|
292 |
+
prompt_pair.unconditional,
|
293 |
+
prompt_pair.positive,
|
294 |
+
prompt_pair.batch_size,
|
295 |
+
),
|
296 |
+
guidance_scale=1,
|
297 |
+
).to("cpu", dtype=torch.float32)
|
298 |
+
|
299 |
+
|
300 |
+
high_latents.requires_grad = False
|
301 |
+
low_latents.requires_grad = False
|
302 |
+
|
303 |
+
loss_high = criteria(target_latents_high, high_noise.cpu().to(torch.float32))
|
304 |
+
pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
|
305 |
+
loss_high.backward()
|
306 |
+
|
307 |
+
|
308 |
+
network.set_lora_slider(scale=-scale_to_look)
|
309 |
+
with network:
|
310 |
+
target_latents_low = train_util.predict_noise(
|
311 |
+
unet,
|
312 |
+
noise_scheduler,
|
313 |
+
current_timestep,
|
314 |
+
denoised_latents_low,
|
315 |
+
train_util.concat_embeddings(
|
316 |
+
prompt_pair.unconditional,
|
317 |
+
prompt_pair.neutral,
|
318 |
+
prompt_pair.batch_size,
|
319 |
+
),
|
320 |
+
guidance_scale=1,
|
321 |
+
).to("cpu", dtype=torch.float32)
|
322 |
+
|
323 |
+
|
324 |
+
high_latents.requires_grad = False
|
325 |
+
low_latents.requires_grad = False
|
326 |
+
|
327 |
+
loss_low = criteria(target_latents_low, low_noise.cpu().to(torch.float32))
|
328 |
+
pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
|
329 |
+
loss_low.backward()
|
330 |
+
|
331 |
+
## NOTICE NO zero_grad between these steps (accumulating gradients)
|
332 |
+
#following guidelines from Ostris (https://github.com/ostris/ai-toolkit)
|
333 |
+
|
334 |
+
optimizer.step()
|
335 |
+
lr_scheduler.step()
|
336 |
+
|
337 |
+
del (
|
338 |
+
high_latents,
|
339 |
+
low_latents,
|
340 |
+
target_latents_low,
|
341 |
+
target_latents_high,
|
342 |
+
)
|
343 |
+
flush()
|
344 |
+
|
345 |
+
if (
|
346 |
+
i % config.save.per_steps == 0
|
347 |
+
and i != 0
|
348 |
+
and i != config.train.iterations - 1
|
349 |
+
):
|
350 |
+
print("Saving...")
|
351 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
352 |
+
network.save_weights(
|
353 |
+
save_path / f"{config.save.name}_{i}steps.pt",
|
354 |
+
dtype=save_weight_dtype,
|
355 |
+
)
|
356 |
+
|
357 |
+
print("Saving...")
|
358 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
359 |
+
network.save_weights(
|
360 |
+
save_path / f"{config.save.name}_last.pt",
|
361 |
+
dtype=save_weight_dtype,
|
362 |
+
)
|
363 |
+
|
364 |
+
del (
|
365 |
+
unet,
|
366 |
+
noise_scheduler,
|
367 |
+
optimizer,
|
368 |
+
network,
|
369 |
+
)
|
370 |
+
|
371 |
+
flush()
|
372 |
+
|
373 |
+
print("Done.")
|
374 |
+
|
375 |
+
|
376 |
+
def main(args):
|
377 |
+
config_file = args.config_file
|
378 |
+
|
379 |
+
config = config_util.load_config_from_yaml(config_file)
|
380 |
+
if args.name is not None:
|
381 |
+
config.save.name = args.name
|
382 |
+
attributes = []
|
383 |
+
if args.attributes is not None:
|
384 |
+
attributes = args.attributes.split(',')
|
385 |
+
attributes = [a.strip() for a in attributes]
|
386 |
+
|
387 |
+
config.network.alpha = args.alpha
|
388 |
+
config.network.rank = args.rank
|
389 |
+
config.save.name += f'_alpha{args.alpha}'
|
390 |
+
config.save.name += f'_rank{config.network.rank }'
|
391 |
+
config.save.name += f'_{config.network.training_method}'
|
392 |
+
config.save.path += f'/{config.save.name}'
|
393 |
+
|
394 |
+
prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
|
395 |
+
device = torch.device(f"cuda:{args.device}")
|
396 |
+
|
397 |
+
|
398 |
+
folders = args.folders.split(',')
|
399 |
+
folders = [f.strip() for f in folders]
|
400 |
+
scales = args.scales.split(',')
|
401 |
+
scales = [f.strip() for f in scales]
|
402 |
+
scales = [int(s) for s in scales]
|
403 |
+
|
404 |
+
print(folders, scales)
|
405 |
+
if len(scales) != len(folders):
|
406 |
+
raise Exception('the number of folders need to match the number of scales')
|
407 |
+
|
408 |
+
if args.stylecheck is not None:
|
409 |
+
check = args.stylecheck.split('-')
|
410 |
+
|
411 |
+
for i in range(int(check[0]), int(check[1])):
|
412 |
+
folder_main = args.folder_main+ f'{i}'
|
413 |
+
config.save.name = f'{os.path.basename(folder_main)}'
|
414 |
+
config.save.name += f'_alpha{args.alpha}'
|
415 |
+
config.save.name += f'_rank{config.network.rank }'
|
416 |
+
config.save.path = f'models/{config.save.name}'
|
417 |
+
train(config=config, prompts=prompts, device=device, folder_main = folder_main)
|
418 |
+
else:
|
419 |
+
train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales)
|
420 |
+
|
421 |
+
if __name__ == "__main__":
|
422 |
+
parser = argparse.ArgumentParser()
|
423 |
+
parser.add_argument(
|
424 |
+
"--config_file",
|
425 |
+
required=False,
|
426 |
+
default = 'data/config.yaml',
|
427 |
+
help="Config file for training.",
|
428 |
+
)
|
429 |
+
parser.add_argument(
|
430 |
+
"--alpha",
|
431 |
+
type=float,
|
432 |
+
required=True,
|
433 |
+
help="LoRA weight.",
|
434 |
+
)
|
435 |
+
|
436 |
+
parser.add_argument(
|
437 |
+
"--rank",
|
438 |
+
type=int,
|
439 |
+
required=False,
|
440 |
+
help="Rank of LoRA.",
|
441 |
+
default=4,
|
442 |
+
)
|
443 |
+
|
444 |
+
parser.add_argument(
|
445 |
+
"--device",
|
446 |
+
type=int,
|
447 |
+
required=False,
|
448 |
+
default=0,
|
449 |
+
help="Device to train on.",
|
450 |
+
)
|
451 |
+
|
452 |
+
parser.add_argument(
|
453 |
+
"--name",
|
454 |
+
type=str,
|
455 |
+
required=False,
|
456 |
+
default=None,
|
457 |
+
help="Device to train on.",
|
458 |
+
)
|
459 |
+
|
460 |
+
parser.add_argument(
|
461 |
+
"--attributes",
|
462 |
+
type=str,
|
463 |
+
required=False,
|
464 |
+
default=None,
|
465 |
+
help="attritbutes to disentangle",
|
466 |
+
)
|
467 |
+
|
468 |
+
parser.add_argument(
|
469 |
+
"--folder_main",
|
470 |
+
type=str,
|
471 |
+
required=True,
|
472 |
+
help="The folder to check",
|
473 |
+
)
|
474 |
+
|
475 |
+
parser.add_argument(
|
476 |
+
"--stylecheck",
|
477 |
+
type=str,
|
478 |
+
required=False,
|
479 |
+
default = None,
|
480 |
+
help="The folder to check",
|
481 |
+
)
|
482 |
+
|
483 |
+
parser.add_argument(
|
484 |
+
"--folders",
|
485 |
+
type=str,
|
486 |
+
required=False,
|
487 |
+
default = 'verylow, low, high, veryhigh',
|
488 |
+
help="folders with different attribute-scaled images",
|
489 |
+
)
|
490 |
+
parser.add_argument(
|
491 |
+
"--scales",
|
492 |
+
type=str,
|
493 |
+
required=False,
|
494 |
+
default = '-2, -1,1, 2',
|
495 |
+
help="scales for different attribute-scaled images",
|
496 |
+
)
|
497 |
+
|
498 |
+
|
499 |
+
args = parser.parse_args()
|
500 |
+
|
501 |
+
main(args)
|
trainscripts/imagesliders/train_util.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
6 |
+
from diffusers import UNet2DConditionModel, SchedulerMixin
|
7 |
+
from diffusers.image_processor import VaeImageProcessor
|
8 |
+
from model_util import SDXL_TEXT_ENCODER_TYPE
|
9 |
+
from diffusers.utils import randn_tensor
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
UNET_IN_CHANNELS = 4 # Stable Diffusion ใฎ in_channels ใฏ 4 ใงๅบๅฎใXLใๅใใ
|
14 |
+
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
15 |
+
|
16 |
+
UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
|
17 |
+
TEXT_ENCODER_2_PROJECTION_DIM = 1280
|
18 |
+
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
|
19 |
+
|
20 |
+
|
21 |
+
def get_random_noise(
|
22 |
+
batch_size: int, height: int, width: int, generator: torch.Generator = None
|
23 |
+
) -> torch.Tensor:
|
24 |
+
return torch.randn(
|
25 |
+
(
|
26 |
+
batch_size,
|
27 |
+
UNET_IN_CHANNELS,
|
28 |
+
height // VAE_SCALE_FACTOR, # ็ธฆใจๆจชใใใงใใฃใฆใใฎใใใใใชใใใฉใใฉใฃใกใซใใๅคงใใชๅ้กใฏ็บ็ใใชใใฎใงใใใงใใใ
|
29 |
+
width // VAE_SCALE_FACTOR,
|
30 |
+
),
|
31 |
+
generator=generator,
|
32 |
+
device="cpu",
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
# https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
37 |
+
def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
|
38 |
+
latents = latents + noise_offset * torch.randn(
|
39 |
+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
|
40 |
+
)
|
41 |
+
return latents
|
42 |
+
|
43 |
+
|
44 |
+
def get_initial_latents(
|
45 |
+
scheduler: SchedulerMixin,
|
46 |
+
n_imgs: int,
|
47 |
+
height: int,
|
48 |
+
width: int,
|
49 |
+
n_prompts: int,
|
50 |
+
generator=None,
|
51 |
+
) -> torch.Tensor:
|
52 |
+
noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
|
53 |
+
n_prompts, 1, 1, 1
|
54 |
+
)
|
55 |
+
|
56 |
+
latents = noise * scheduler.init_noise_sigma
|
57 |
+
|
58 |
+
return latents
|
59 |
+
|
60 |
+
|
61 |
+
def text_tokenize(
|
62 |
+
tokenizer: CLIPTokenizer, # ๆฎ้ใชใใฒใจใคใXLใชใใตใใค๏ผ
|
63 |
+
prompts: list[str],
|
64 |
+
):
|
65 |
+
return tokenizer(
|
66 |
+
prompts,
|
67 |
+
padding="max_length",
|
68 |
+
max_length=tokenizer.model_max_length,
|
69 |
+
truncation=True,
|
70 |
+
return_tensors="pt",
|
71 |
+
).input_ids
|
72 |
+
|
73 |
+
|
74 |
+
def text_encode(text_encoder: CLIPTextModel, tokens):
|
75 |
+
return text_encoder(tokens.to(text_encoder.device))[0]
|
76 |
+
|
77 |
+
|
78 |
+
def encode_prompts(
|
79 |
+
tokenizer: CLIPTokenizer,
|
80 |
+
text_encoder: CLIPTokenizer,
|
81 |
+
prompts: list[str],
|
82 |
+
):
|
83 |
+
|
84 |
+
text_tokens = text_tokenize(tokenizer, prompts)
|
85 |
+
text_embeddings = text_encode(text_encoder, text_tokens)
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
return text_embeddings
|
90 |
+
|
91 |
+
|
92 |
+
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
|
93 |
+
def text_encode_xl(
|
94 |
+
text_encoder: SDXL_TEXT_ENCODER_TYPE,
|
95 |
+
tokens: torch.FloatTensor,
|
96 |
+
num_images_per_prompt: int = 1,
|
97 |
+
):
|
98 |
+
prompt_embeds = text_encoder(
|
99 |
+
tokens.to(text_encoder.device), output_hidden_states=True
|
100 |
+
)
|
101 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
102 |
+
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
|
103 |
+
|
104 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
105 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
106 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
107 |
+
|
108 |
+
return prompt_embeds, pooled_prompt_embeds
|
109 |
+
|
110 |
+
|
111 |
+
def encode_prompts_xl(
|
112 |
+
tokenizers: list[CLIPTokenizer],
|
113 |
+
text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
|
114 |
+
prompts: list[str],
|
115 |
+
num_images_per_prompt: int = 1,
|
116 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
117 |
+
# text_encoder and text_encoder_2's penuultimate layer's output
|
118 |
+
text_embeds_list = []
|
119 |
+
pooled_text_embeds = None # always text_encoder_2's pool
|
120 |
+
|
121 |
+
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
122 |
+
text_tokens_input_ids = text_tokenize(tokenizer, prompts)
|
123 |
+
text_embeds, pooled_text_embeds = text_encode_xl(
|
124 |
+
text_encoder, text_tokens_input_ids, num_images_per_prompt
|
125 |
+
)
|
126 |
+
|
127 |
+
text_embeds_list.append(text_embeds)
|
128 |
+
|
129 |
+
bs_embed = pooled_text_embeds.shape[0]
|
130 |
+
pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
|
131 |
+
bs_embed * num_images_per_prompt, -1
|
132 |
+
)
|
133 |
+
|
134 |
+
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
|
135 |
+
|
136 |
+
|
137 |
+
def concat_embeddings(
|
138 |
+
unconditional: torch.FloatTensor,
|
139 |
+
conditional: torch.FloatTensor,
|
140 |
+
n_imgs: int,
|
141 |
+
):
|
142 |
+
return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
|
143 |
+
|
144 |
+
|
145 |
+
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
|
146 |
+
def predict_noise(
|
147 |
+
unet: UNet2DConditionModel,
|
148 |
+
scheduler: SchedulerMixin,
|
149 |
+
timestep: int, # ็พๅจใฎใฟใคใ ในใใใ
|
150 |
+
latents: torch.FloatTensor,
|
151 |
+
text_embeddings: torch.FloatTensor, # uncond ใช text embed ใจ cond ใช text embed ใ็ตๅใใใใฎ
|
152 |
+
guidance_scale=7.5,
|
153 |
+
) -> torch.FloatTensor:
|
154 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
155 |
+
latent_model_input = torch.cat([latents] * 2)
|
156 |
+
|
157 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
158 |
+
|
159 |
+
# predict the noise residual
|
160 |
+
noise_pred = unet(
|
161 |
+
latent_model_input,
|
162 |
+
timestep,
|
163 |
+
encoder_hidden_states=text_embeddings,
|
164 |
+
).sample
|
165 |
+
|
166 |
+
# perform guidance
|
167 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
168 |
+
guided_target = noise_pred_uncond + guidance_scale * (
|
169 |
+
noise_pred_text - noise_pred_uncond
|
170 |
+
)
|
171 |
+
|
172 |
+
return guided_target
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
|
177 |
+
@torch.no_grad()
|
178 |
+
def diffusion(
|
179 |
+
unet: UNet2DConditionModel,
|
180 |
+
scheduler: SchedulerMixin,
|
181 |
+
latents: torch.FloatTensor, # ใใ ใฎใใคใบใ ใใฎlatents
|
182 |
+
text_embeddings: torch.FloatTensor,
|
183 |
+
total_timesteps: int = 1000,
|
184 |
+
start_timesteps=0,
|
185 |
+
**kwargs,
|
186 |
+
):
|
187 |
+
# latents_steps = []
|
188 |
+
|
189 |
+
for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
|
190 |
+
noise_pred = predict_noise(
|
191 |
+
unet, scheduler, timestep, latents, text_embeddings, **kwargs
|
192 |
+
)
|
193 |
+
|
194 |
+
# compute the previous noisy sample x_t -> x_t-1
|
195 |
+
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
196 |
+
|
197 |
+
# return latents_steps
|
198 |
+
return latents
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def get_noisy_image(
|
202 |
+
img,
|
203 |
+
vae,
|
204 |
+
generator,
|
205 |
+
unet: UNet2DConditionModel,
|
206 |
+
scheduler: SchedulerMixin,
|
207 |
+
total_timesteps: int = 1000,
|
208 |
+
start_timesteps=0,
|
209 |
+
|
210 |
+
**kwargs,
|
211 |
+
):
|
212 |
+
# latents_steps = []
|
213 |
+
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
214 |
+
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
|
215 |
+
|
216 |
+
image = img
|
217 |
+
im_orig = image
|
218 |
+
device = vae.device
|
219 |
+
image = image_processor.preprocess(image).to(device)
|
220 |
+
|
221 |
+
init_latents = vae.encode(image).latent_dist.sample(None)
|
222 |
+
init_latents = vae.config.scaling_factor * init_latents
|
223 |
+
|
224 |
+
init_latents = torch.cat([init_latents], dim=0)
|
225 |
+
|
226 |
+
shape = init_latents.shape
|
227 |
+
|
228 |
+
noise = randn_tensor(shape, generator=generator, device=device)
|
229 |
+
|
230 |
+
time_ = total_timesteps
|
231 |
+
timestep = scheduler.timesteps[time_:time_+1]
|
232 |
+
# get latents
|
233 |
+
init_latents = scheduler.add_noise(init_latents, noise, timestep)
|
234 |
+
|
235 |
+
return init_latents, noise
|
236 |
+
|
237 |
+
|
238 |
+
def rescale_noise_cfg(
|
239 |
+
noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
|
240 |
+
):
|
241 |
+
"""
|
242 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
243 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
244 |
+
"""
|
245 |
+
std_text = noise_pred_text.std(
|
246 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
|
247 |
+
)
|
248 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
249 |
+
# rescale the results from guidance (fixes overexposure)
|
250 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
251 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
252 |
+
noise_cfg = (
|
253 |
+
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
254 |
+
)
|
255 |
+
|
256 |
+
return noise_cfg
|
257 |
+
|
258 |
+
|
259 |
+
def predict_noise_xl(
|
260 |
+
unet: UNet2DConditionModel,
|
261 |
+
scheduler: SchedulerMixin,
|
262 |
+
timestep: int, # ็พๅจใฎใฟใคใ ในใใใ
|
263 |
+
latents: torch.FloatTensor,
|
264 |
+
text_embeddings: torch.FloatTensor, # uncond ใช text embed ใจ cond ใช text embed ใ็ตๅใใใใฎ
|
265 |
+
add_text_embeddings: torch.FloatTensor, # pooled ใชใใค
|
266 |
+
add_time_ids: torch.FloatTensor,
|
267 |
+
guidance_scale=7.5,
|
268 |
+
guidance_rescale=0.7,
|
269 |
+
) -> torch.FloatTensor:
|
270 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
271 |
+
latent_model_input = torch.cat([latents] * 2)
|
272 |
+
|
273 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
274 |
+
|
275 |
+
added_cond_kwargs = {
|
276 |
+
"text_embeds": add_text_embeddings,
|
277 |
+
"time_ids": add_time_ids,
|
278 |
+
}
|
279 |
+
|
280 |
+
# predict the noise residual
|
281 |
+
noise_pred = unet(
|
282 |
+
latent_model_input,
|
283 |
+
timestep,
|
284 |
+
encoder_hidden_states=text_embeddings,
|
285 |
+
added_cond_kwargs=added_cond_kwargs,
|
286 |
+
).sample
|
287 |
+
|
288 |
+
# perform guidance
|
289 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
290 |
+
guided_target = noise_pred_uncond + guidance_scale * (
|
291 |
+
noise_pred_text - noise_pred_uncond
|
292 |
+
)
|
293 |
+
|
294 |
+
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
295 |
+
noise_pred = rescale_noise_cfg(
|
296 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
297 |
+
)
|
298 |
+
|
299 |
+
return guided_target
|
300 |
+
|
301 |
+
|
302 |
+
@torch.no_grad()
|
303 |
+
def diffusion_xl(
|
304 |
+
unet: UNet2DConditionModel,
|
305 |
+
scheduler: SchedulerMixin,
|
306 |
+
latents: torch.FloatTensor, # ใใ ใฎใใคใบใ ใใฎlatents
|
307 |
+
text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
|
308 |
+
add_text_embeddings: torch.FloatTensor, # pooled ใชใใค
|
309 |
+
add_time_ids: torch.FloatTensor,
|
310 |
+
guidance_scale: float = 1.0,
|
311 |
+
total_timesteps: int = 1000,
|
312 |
+
start_timesteps=0,
|
313 |
+
):
|
314 |
+
# latents_steps = []
|
315 |
+
|
316 |
+
for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
|
317 |
+
noise_pred = predict_noise_xl(
|
318 |
+
unet,
|
319 |
+
scheduler,
|
320 |
+
timestep,
|
321 |
+
latents,
|
322 |
+
text_embeddings,
|
323 |
+
add_text_embeddings,
|
324 |
+
add_time_ids,
|
325 |
+
guidance_scale=guidance_scale,
|
326 |
+
guidance_rescale=0.7,
|
327 |
+
)
|
328 |
+
|
329 |
+
# compute the previous noisy sample x_t -> x_t-1
|
330 |
+
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
331 |
+
|
332 |
+
# return latents_steps
|
333 |
+
return latents
|
334 |
+
|
335 |
+
|
336 |
+
# for XL
|
337 |
+
def get_add_time_ids(
|
338 |
+
height: int,
|
339 |
+
width: int,
|
340 |
+
dynamic_crops: bool = False,
|
341 |
+
dtype: torch.dtype = torch.float32,
|
342 |
+
):
|
343 |
+
if dynamic_crops:
|
344 |
+
# random float scale between 1 and 3
|
345 |
+
random_scale = torch.rand(1).item() * 2 + 1
|
346 |
+
original_size = (int(height * random_scale), int(width * random_scale))
|
347 |
+
# random position
|
348 |
+
crops_coords_top_left = (
|
349 |
+
torch.randint(0, original_size[0] - height, (1,)).item(),
|
350 |
+
torch.randint(0, original_size[1] - width, (1,)).item(),
|
351 |
+
)
|
352 |
+
target_size = (height, width)
|
353 |
+
else:
|
354 |
+
original_size = (height, width)
|
355 |
+
crops_coords_top_left = (0, 0)
|
356 |
+
target_size = (height, width)
|
357 |
+
|
358 |
+
# this is expected as 6
|
359 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
360 |
+
|
361 |
+
# this is expected as 2816
|
362 |
+
passed_add_embed_dim = (
|
363 |
+
UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
|
364 |
+
+ TEXT_ENCODER_2_PROJECTION_DIM # + 1280
|
365 |
+
)
|
366 |
+
if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
|
367 |
+
raise ValueError(
|
368 |
+
f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
369 |
+
)
|
370 |
+
|
371 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
372 |
+
return add_time_ids
|
373 |
+
|
374 |
+
|
375 |
+
def get_optimizer(name: str):
|
376 |
+
name = name.lower()
|
377 |
+
|
378 |
+
if name.startswith("dadapt"):
|
379 |
+
import dadaptation
|
380 |
+
|
381 |
+
if name == "dadaptadam":
|
382 |
+
return dadaptation.DAdaptAdam
|
383 |
+
elif name == "dadaptlion":
|
384 |
+
return dadaptation.DAdaptLion
|
385 |
+
else:
|
386 |
+
raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
|
387 |
+
|
388 |
+
elif name.endswith("8bit"): # ๆค่จผใใฆใชใ
|
389 |
+
import bitsandbytes as bnb
|
390 |
+
|
391 |
+
if name == "adam8bit":
|
392 |
+
return bnb.optim.Adam8bit
|
393 |
+
elif name == "lion8bit":
|
394 |
+
return bnb.optim.Lion8bit
|
395 |
+
else:
|
396 |
+
raise ValueError("8bit optimizer must be adam8bit or lion8bit")
|
397 |
+
|
398 |
+
else:
|
399 |
+
if name == "adam":
|
400 |
+
return torch.optim.Adam
|
401 |
+
elif name == "adamw":
|
402 |
+
return torch.optim.AdamW
|
403 |
+
elif name == "lion":
|
404 |
+
from lion_pytorch import Lion
|
405 |
+
|
406 |
+
return Lion
|
407 |
+
elif name == "prodigy":
|
408 |
+
import prodigyopt
|
409 |
+
|
410 |
+
return prodigyopt.Prodigy
|
411 |
+
else:
|
412 |
+
raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
|
413 |
+
|
414 |
+
|
415 |
+
def get_lr_scheduler(
|
416 |
+
name: Optional[str],
|
417 |
+
optimizer: torch.optim.Optimizer,
|
418 |
+
max_iterations: Optional[int],
|
419 |
+
lr_min: Optional[float],
|
420 |
+
**kwargs,
|
421 |
+
):
|
422 |
+
if name == "cosine":
|
423 |
+
return torch.optim.lr_scheduler.CosineAnnealingLR(
|
424 |
+
optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
|
425 |
+
)
|
426 |
+
elif name == "cosine_with_restarts":
|
427 |
+
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
428 |
+
optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
|
429 |
+
)
|
430 |
+
elif name == "step":
|
431 |
+
return torch.optim.lr_scheduler.StepLR(
|
432 |
+
optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
|
433 |
+
)
|
434 |
+
elif name == "constant":
|
435 |
+
return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
|
436 |
+
elif name == "linear":
|
437 |
+
return torch.optim.lr_scheduler.LinearLR(
|
438 |
+
optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
|
439 |
+
)
|
440 |
+
else:
|
441 |
+
raise ValueError(
|
442 |
+
"Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
|
443 |
+
)
|
444 |
+
|
445 |
+
|
446 |
+
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
|
447 |
+
max_resolution = bucket_resolution
|
448 |
+
min_resolution = bucket_resolution // 2
|
449 |
+
|
450 |
+
step = 64
|
451 |
+
|
452 |
+
min_step = min_resolution // step
|
453 |
+
max_step = max_resolution // step
|
454 |
+
|
455 |
+
height = torch.randint(min_step, max_step, (1,)).item() * step
|
456 |
+
width = torch.randint(min_step, max_step, (1,)).item() * step
|
457 |
+
|
458 |
+
return height, width
|
trainscripts/textsliders/__init__.py
ADDED
File without changes
|
trainscripts/textsliders/config_util.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Optional
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from lora import TRAINING_METHODS
|
9 |
+
|
10 |
+
PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
|
11 |
+
NETWORK_TYPES = Literal["lierla", "c3lier"]
|
12 |
+
|
13 |
+
|
14 |
+
class PretrainedModelConfig(BaseModel):
|
15 |
+
name_or_path: str
|
16 |
+
v2: bool = False
|
17 |
+
v_pred: bool = False
|
18 |
+
|
19 |
+
clip_skip: Optional[int] = None
|
20 |
+
|
21 |
+
|
22 |
+
class NetworkConfig(BaseModel):
|
23 |
+
type: NETWORK_TYPES = "lierla"
|
24 |
+
rank: int = 4
|
25 |
+
alpha: float = 1.0
|
26 |
+
|
27 |
+
training_method: TRAINING_METHODS = "full"
|
28 |
+
|
29 |
+
|
30 |
+
class TrainConfig(BaseModel):
|
31 |
+
precision: PRECISION_TYPES = "bfloat16"
|
32 |
+
noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
|
33 |
+
|
34 |
+
iterations: int = 500
|
35 |
+
lr: float = 1e-4
|
36 |
+
optimizer: str = "adamw"
|
37 |
+
optimizer_args: str = ""
|
38 |
+
lr_scheduler: str = "constant"
|
39 |
+
|
40 |
+
max_denoising_steps: int = 50
|
41 |
+
|
42 |
+
|
43 |
+
class SaveConfig(BaseModel):
|
44 |
+
name: str = "untitled"
|
45 |
+
path: str = "./output"
|
46 |
+
per_steps: int = 200
|
47 |
+
precision: PRECISION_TYPES = "float32"
|
48 |
+
|
49 |
+
|
50 |
+
class LoggingConfig(BaseModel):
|
51 |
+
use_wandb: bool = False
|
52 |
+
|
53 |
+
verbose: bool = False
|
54 |
+
|
55 |
+
|
56 |
+
class OtherConfig(BaseModel):
|
57 |
+
use_xformers: bool = False
|
58 |
+
|
59 |
+
|
60 |
+
class RootConfig(BaseModel):
|
61 |
+
prompts_file: str
|
62 |
+
pretrained_model: PretrainedModelConfig
|
63 |
+
|
64 |
+
network: NetworkConfig
|
65 |
+
|
66 |
+
train: Optional[TrainConfig]
|
67 |
+
|
68 |
+
save: Optional[SaveConfig]
|
69 |
+
|
70 |
+
logging: Optional[LoggingConfig]
|
71 |
+
|
72 |
+
other: Optional[OtherConfig]
|
73 |
+
|
74 |
+
|
75 |
+
def parse_precision(precision: str) -> torch.dtype:
|
76 |
+
if precision == "fp32" or precision == "float32":
|
77 |
+
return torch.float32
|
78 |
+
elif precision == "fp16" or precision == "float16":
|
79 |
+
return torch.float16
|
80 |
+
elif precision == "bf16" or precision == "bfloat16":
|
81 |
+
return torch.bfloat16
|
82 |
+
|
83 |
+
raise ValueError(f"Invalid precision type: {precision}")
|
84 |
+
|
85 |
+
|
86 |
+
def load_config_from_yaml(config_path: str) -> RootConfig:
|
87 |
+
with open(config_path, "r") as f:
|
88 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
89 |
+
|
90 |
+
root = RootConfig(**config)
|
91 |
+
|
92 |
+
if root.train is None:
|
93 |
+
root.train = TrainConfig()
|
94 |
+
|
95 |
+
if root.save is None:
|
96 |
+
root.save = SaveConfig()
|
97 |
+
|
98 |
+
if root.logging is None:
|
99 |
+
root.logging = LoggingConfig()
|
100 |
+
|
101 |
+
if root.other is None:
|
102 |
+
root.other = OtherConfig()
|
103 |
+
|
104 |
+
return root
|
trainscripts/textsliders/data/config-xl.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompts_file: "trainscripts/textsliders/data/prompts-xl.yaml"
|
2 |
+
pretrained_model:
|
3 |
+
name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
|
4 |
+
v2: false # true if model is v2.x
|
5 |
+
v_pred: false # true if model uses v-prediction
|
6 |
+
network:
|
7 |
+
type: "c3lier" # or "c3lier" or "lierla"
|
8 |
+
rank: 4
|
9 |
+
alpha: 1.0
|
10 |
+
training_method: "noxattn"
|
11 |
+
train:
|
12 |
+
precision: "bfloat16"
|
13 |
+
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
14 |
+
iterations: 1000
|
15 |
+
lr: 0.0002
|
16 |
+
optimizer: "AdamW"
|
17 |
+
lr_scheduler: "constant"
|
18 |
+
max_denoising_steps: 50
|
19 |
+
save:
|
20 |
+
name: "temp"
|
21 |
+
path: "./models"
|
22 |
+
per_steps: 500
|
23 |
+
precision: "bfloat16"
|
24 |
+
logging:
|
25 |
+
use_wandb: false
|
26 |
+
verbose: false
|
27 |
+
other:
|
28 |
+
use_xformers: true
|
trainscripts/textsliders/data/config.yaml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompts_file: "trainscripts/textsliders/data/prompts.yaml"
|
2 |
+
pretrained_model:
|
3 |
+
name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
|
4 |
+
v2: false # true if model is v2.x
|
5 |
+
v_pred: false # true if model uses v-prediction
|
6 |
+
network:
|
7 |
+
type: "c3lier" # or "c3lier" or "lierla"
|
8 |
+
rank: 4
|
9 |
+
alpha: 1.0
|
10 |
+
training_method: "noxattn"
|
11 |
+
train:
|
12 |
+
precision: "bfloat16"
|
13 |
+
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
14 |
+
iterations: 1000
|
15 |
+
lr: 0.0002
|
16 |
+
optimizer: "AdamW"
|
17 |
+
lr_scheduler: "constant"
|
18 |
+
max_denoising_steps: 50
|
19 |
+
save:
|
20 |
+
name: "temp"
|
21 |
+
path: "./models"
|
22 |
+
per_steps: 500
|
23 |
+
precision: "bfloat16"
|
24 |
+
logging:
|
25 |
+
use_wandb: false
|
26 |
+
verbose: false
|
27 |
+
other:
|
28 |
+
use_xformers: true
|
trainscripts/textsliders/data/prompts-xl.yaml
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
####################################################################################################### AGE SLIDER
|
2 |
+
# - target: "male person" # what word for erasing the positive concept from
|
3 |
+
# positive: "male person, very old" # concept to erase
|
4 |
+
# unconditional: "male person, very young" # word to take the difference from the positive concept
|
5 |
+
# neutral: "male person" # starting point for conditioning the target
|
6 |
+
# action: "enhance" # erase or enhance
|
7 |
+
# guidance_scale: 4
|
8 |
+
# resolution: 512
|
9 |
+
# dynamic_resolution: false
|
10 |
+
# batch_size: 1
|
11 |
+
# - target: "female person" # what word for erasing the positive concept from
|
12 |
+
# positive: "female person, very old" # concept to erase
|
13 |
+
# unconditional: "female person, very young" # word to take the difference from the positive concept
|
14 |
+
# neutral: "female person" # starting point for conditioning the target
|
15 |
+
# action: "enhance" # erase or enhance
|
16 |
+
# guidance_scale: 4
|
17 |
+
# resolution: 512
|
18 |
+
# dynamic_resolution: false
|
19 |
+
# batch_size: 1
|
20 |
+
####################################################################################################### MUSCULAR SLIDER
|
21 |
+
# - target: "male person" # what word for erasing the positive concept from
|
22 |
+
# positive: "male person, muscular, strong, biceps, greek god physique, body builder" # concept to erase
|
23 |
+
# unconditional: "male person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept
|
24 |
+
# neutral: "male person" # starting point for conditioning the target
|
25 |
+
# action: "enhance" # erase or enhance
|
26 |
+
# guidance_scale: 4
|
27 |
+
# resolution: 512
|
28 |
+
# dynamic_resolution: false
|
29 |
+
# batch_size: 1
|
30 |
+
# - target: "female person" # what word for erasing the positive concept from
|
31 |
+
# positive: "female person, muscular, strong, biceps, greek god physique, body builder" # concept to erase
|
32 |
+
# unconditional: "female person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept
|
33 |
+
# neutral: "female person" # starting point for conditioning the target
|
34 |
+
# action: "enhance" # erase or enhance
|
35 |
+
# guidance_scale: 4
|
36 |
+
# resolution: 512
|
37 |
+
# dynamic_resolution: false
|
38 |
+
# batch_size: 1
|
39 |
+
####################################################################################################### CURLY HAIR SLIDER
|
40 |
+
# - target: "male person" # what word for erasing the positive concept from
|
41 |
+
# positive: "male person, curly hair, wavy hair" # concept to erase
|
42 |
+
# unconditional: "male person, straight hair" # word to take the difference from the positive concept
|
43 |
+
# neutral: "male person" # starting point for conditioning the target
|
44 |
+
# action: "enhance" # erase or enhance
|
45 |
+
# guidance_scale: 4
|
46 |
+
# resolution: 512
|
47 |
+
# dynamic_resolution: false
|
48 |
+
# batch_size: 1
|
49 |
+
# - target: "female person" # what word for erasing the positive concept from
|
50 |
+
# positive: "female person, curly hair, wavy hair" # concept to erase
|
51 |
+
# unconditional: "female person, straight hair" # word to take the difference from the positive concept
|
52 |
+
# neutral: "female person" # starting point for conditioning the target
|
53 |
+
# action: "enhance" # erase or enhance
|
54 |
+
# guidance_scale: 4
|
55 |
+
# resolution: 512
|
56 |
+
# dynamic_resolution: false
|
57 |
+
# batch_size: 1
|
58 |
+
####################################################################################################### BEARD SLIDER
|
59 |
+
# - target: "male person" # what word for erasing the positive concept from
|
60 |
+
# positive: "male person, with beard" # concept to erase
|
61 |
+
# unconditional: "male person, clean shaven" # word to take the difference from the positive concept
|
62 |
+
# neutral: "male person" # starting point for conditioning the target
|
63 |
+
# action: "enhance" # erase or enhance
|
64 |
+
# guidance_scale: 4
|
65 |
+
# resolution: 512
|
66 |
+
# dynamic_resolution: false
|
67 |
+
# batch_size: 1
|
68 |
+
# - target: "female person" # what word for erasing the positive concept from
|
69 |
+
# positive: "female person, with beard, lipstick and feminine" # concept to erase
|
70 |
+
# unconditional: "female person, clean shaven" # word to take the difference from the positive concept
|
71 |
+
# neutral: "female person" # starting point for conditioning the target
|
72 |
+
# action: "enhance" # erase or enhance
|
73 |
+
# guidance_scale: 4
|
74 |
+
# resolution: 512
|
75 |
+
# dynamic_resolution: false
|
76 |
+
# batch_size: 1
|
77 |
+
####################################################################################################### MAKEUP SLIDER
|
78 |
+
# - target: "male person" # what word for erasing the positive concept from
|
79 |
+
# positive: "male person, with makeup, cosmetic, concealer, mascara" # concept to erase
|
80 |
+
# unconditional: "male person, barefaced, ugly" # word to take the difference from the positive concept
|
81 |
+
# neutral: "male person" # starting point for conditioning the target
|
82 |
+
# action: "enhance" # erase or enhance
|
83 |
+
# guidance_scale: 4
|
84 |
+
# resolution: 512
|
85 |
+
# dynamic_resolution: false
|
86 |
+
# batch_size: 1
|
87 |
+
# - target: "female person" # what word for erasing the positive concept from
|
88 |
+
# positive: "female person, with makeup, cosmetic, concealer, mascara, lipstick" # concept to erase
|
89 |
+
# unconditional: "female person, barefaced, ugly" # word to take the difference from the positive concept
|
90 |
+
# neutral: "female person" # starting point for conditioning the target
|
91 |
+
# action: "enhance" # erase or enhance
|
92 |
+
# guidance_scale: 4
|
93 |
+
# resolution: 512
|
94 |
+
# dynamic_resolution: false
|
95 |
+
# batch_size: 1
|
96 |
+
####################################################################################################### SURPRISED SLIDER
|
97 |
+
# - target: "male person" # what word for erasing the positive concept from
|
98 |
+
# positive: "male person, with shocked look, surprised, stunned, amazed" # concept to erase
|
99 |
+
# unconditional: "male person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept
|
100 |
+
# neutral: "male person" # starting point for conditioning the target
|
101 |
+
# action: "enhance" # erase or enhance
|
102 |
+
# guidance_scale: 4
|
103 |
+
# resolution: 512
|
104 |
+
# dynamic_resolution: false
|
105 |
+
# batch_size: 1
|
106 |
+
# - target: "female person" # what word for erasing the positive concept from
|
107 |
+
# positive: "female person, with shocked look, surprised, stunned, amazed" # concept to erase
|
108 |
+
# unconditional: "female person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept
|
109 |
+
# neutral: "female person" # starting point for conditioning the target
|
110 |
+
# action: "enhance" # erase or enhance
|
111 |
+
# guidance_scale: 4
|
112 |
+
# resolution: 512
|
113 |
+
# dynamic_resolution: false
|
114 |
+
# batch_size: 1
|
115 |
+
####################################################################################################### OBESE SLIDER
|
116 |
+
# - target: "male person" # what word for erasing the positive concept from
|
117 |
+
# positive: "male person, fat, chubby, overweight, obese" # concept to erase
|
118 |
+
# unconditional: "male person, lean, fit, slim, slender" # word to take the difference from the positive concept
|
119 |
+
# neutral: "male person" # starting point for conditioning the target
|
120 |
+
# action: "enhance" # erase or enhance
|
121 |
+
# guidance_scale: 4
|
122 |
+
# resolution: 512
|
123 |
+
# dynamic_resolution: false
|
124 |
+
# batch_size: 1
|
125 |
+
# - target: "female person" # what word for erasing the positive concept from
|
126 |
+
# positive: "female person, fat, chubby, overweight, obese" # concept to erase
|
127 |
+
# unconditional: "female person, lean, fit, slim, slender" # word to take the difference from the positive concept
|
128 |
+
# neutral: "female person" # starting point for conditioning the target
|
129 |
+
# action: "enhance" # erase or enhance
|
130 |
+
# guidance_scale: 4
|
131 |
+
# resolution: 512
|
132 |
+
# dynamic_resolution: false
|
133 |
+
# batch_size: 1
|
134 |
+
####################################################################################################### PROFESSIONAL SLIDER
|
135 |
+
# - target: "male person" # what word for erasing the positive concept from
|
136 |
+
# positive: "male person, professionally dressed, stylised hair, clean face" # concept to erase
|
137 |
+
# unconditional: "male person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept
|
138 |
+
# neutral: "male person" # starting point for conditioning the target
|
139 |
+
# action: "enhance" # erase or enhance
|
140 |
+
# guidance_scale: 4
|
141 |
+
# resolution: 512
|
142 |
+
# dynamic_resolution: false
|
143 |
+
# batch_size: 1
|
144 |
+
# - target: "female person" # what word for erasing the positive concept from
|
145 |
+
# positive: "female person, professionally dressed, stylised hair, clean face" # concept to erase
|
146 |
+
# unconditional: "female person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept
|
147 |
+
# neutral: "female person" # starting point for conditioning the target
|
148 |
+
# action: "enhance" # erase or enhance
|
149 |
+
# guidance_scale: 4
|
150 |
+
# resolution: 512
|
151 |
+
# dynamic_resolution: false
|
152 |
+
# batch_size: 1
|
153 |
+
####################################################################################################### GLASSES SLIDER
|
154 |
+
# - target: "male person" # what word for erasing the positive concept from
|
155 |
+
# positive: "male person, wearing glasses" # concept to erase
|
156 |
+
# unconditional: "male person" # word to take the difference from the positive concept
|
157 |
+
# neutral: "male person" # starting point for conditioning the target
|
158 |
+
# action: "enhance" # erase or enhance
|
159 |
+
# guidance_scale: 4
|
160 |
+
# resolution: 512
|
161 |
+
# dynamic_resolution: false
|
162 |
+
# batch_size: 1
|
163 |
+
# - target: "female person" # what word for erasing the positive concept from
|
164 |
+
# positive: "female person, wearing glasses" # concept to erase
|
165 |
+
# unconditional: "female person" # word to take the difference from the positive concept
|
166 |
+
# neutral: "female person" # starting point for conditioning the target
|
167 |
+
# action: "enhance" # erase or enhance
|
168 |
+
# guidance_scale: 4
|
169 |
+
# resolution: 512
|
170 |
+
# dynamic_resolution: false
|
171 |
+
# batch_size: 1
|
172 |
+
####################################################################################################### ASTRONAUGHT SLIDER
|
173 |
+
# - target: "astronaught" # what word for erasing the positive concept from
|
174 |
+
# positive: "astronaught, with orange colored spacesuit" # concept to erase
|
175 |
+
# unconditional: "astronaught" # word to take the difference from the positive concept
|
176 |
+
# neutral: "astronaught" # starting point for conditioning the target
|
177 |
+
# action: "enhance" # erase or enhance
|
178 |
+
# guidance_scale: 4
|
179 |
+
# resolution: 512
|
180 |
+
# dynamic_resolution: false
|
181 |
+
# batch_size: 1
|
182 |
+
####################################################################################################### SMILING SLIDER
|
183 |
+
# - target: "male person" # what word for erasing the positive concept from
|
184 |
+
# positive: "male person, smiling" # concept to erase
|
185 |
+
# unconditional: "male person, frowning" # word to take the difference from the positive concept
|
186 |
+
# neutral: "male person" # starting point for conditioning the target
|
187 |
+
# action: "enhance" # erase or enhance
|
188 |
+
# guidance_scale: 4
|
189 |
+
# resolution: 512
|
190 |
+
# dynamic_resolution: false
|
191 |
+
# batch_size: 1
|
192 |
+
# - target: "female person" # what word for erasing the positive concept from
|
193 |
+
# positive: "female person, smiling" # concept to erase
|
194 |
+
# unconditional: "female person, frowning" # word to take the difference from the positive concept
|
195 |
+
# neutral: "female person" # starting point for conditioning the target
|
196 |
+
# action: "enhance" # erase or enhance
|
197 |
+
# guidance_scale: 4
|
198 |
+
# resolution: 512
|
199 |
+
# dynamic_resolution: false
|
200 |
+
# batch_size: 1
|
201 |
+
####################################################################################################### CAR COLOR SLIDER
|
202 |
+
# - target: "car" # what word for erasing the positive concept from
|
203 |
+
# positive: "car, white color" # concept to erase
|
204 |
+
# unconditional: "car, black color" # word to take the difference from the positive concept
|
205 |
+
# neutral: "car" # starting point for conditioning the target
|
206 |
+
# action: "enhance" # erase or enhance
|
207 |
+
# guidance_scale: 4
|
208 |
+
# resolution: 512
|
209 |
+
# dynamic_resolution: false
|
210 |
+
# batch_size: 1
|
211 |
+
####################################################################################################### DETAILS SLIDER
|
212 |
+
# - target: "" # what word for erasing the positive concept from
|
213 |
+
# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase
|
214 |
+
# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
|
215 |
+
# neutral: "" # starting point for conditioning the target
|
216 |
+
# action: "enhance" # erase or enhance
|
217 |
+
# guidance_scale: 4
|
218 |
+
# resolution: 512
|
219 |
+
# dynamic_resolution: false
|
220 |
+
# batch_size: 1
|
221 |
+
####################################################################################################### CARTOON SLIDER
|
222 |
+
# - target: "male person" # what word for erasing the positive concept from
|
223 |
+
# positive: "male person, cartoon style, pixar style, animated style" # concept to erase
|
224 |
+
# unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
|
225 |
+
# neutral: "male person" # starting point for conditioning the target
|
226 |
+
# action: "enhance" # erase or enhance
|
227 |
+
# guidance_scale: 4
|
228 |
+
# resolution: 512
|
229 |
+
# dynamic_resolution: false
|
230 |
+
# batch_size: 1
|
231 |
+
# - target: "female person" # what word for erasing the positive concept from
|
232 |
+
# positive: "female person, cartoon style, pixar style, animated style" # concept to erase
|
233 |
+
# unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
|
234 |
+
# neutral: "female person" # starting point for conditioning the target
|
235 |
+
# action: "enhance" # erase or enhance
|
236 |
+
# guidance_scale: 4
|
237 |
+
# resolution: 512
|
238 |
+
# dynamic_resolution: false
|
239 |
+
# batch_size: 1
|
240 |
+
####################################################################################################### CLAY SLIDER
|
241 |
+
# - target: "male person" # what word for erasing the positive concept from
|
242 |
+
# positive: "male person, clay style, made out of clay, clay sculpture" # concept to erase
|
243 |
+
# unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
|
244 |
+
# neutral: "male person" # starting point for conditioning the target
|
245 |
+
# action: "enhance" # erase or enhance
|
246 |
+
# guidance_scale: 4
|
247 |
+
# resolution: 512
|
248 |
+
# dynamic_resolution: false
|
249 |
+
# batch_size: 1
|
250 |
+
# - target: "female person" # what word for erasing the positive concept from
|
251 |
+
# positive: "female person, clay style, made out of clay, clay sculpture" # concept to erase
|
252 |
+
# unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
|
253 |
+
# neutral: "female person" # starting point for conditioning the target
|
254 |
+
# action: "enhance" # erase or enhance
|
255 |
+
# guidance_scale: 4
|
256 |
+
# resolution: 512
|
257 |
+
# dynamic_resolution: false
|
258 |
+
# batch_size: 1
|
259 |
+
####################################################################################################### SCULPTURE SLIDER
|
260 |
+
- target: "male person" # what word for erasing the positive concept from
|
261 |
+
positive: "male person, cement sculpture, cement greek statue style" # concept to erase
|
262 |
+
unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
|
263 |
+
neutral: "male person" # starting point for conditioning the target
|
264 |
+
action: "enhance" # erase or enhance
|
265 |
+
guidance_scale: 4
|
266 |
+
resolution: 512
|
267 |
+
dynamic_resolution: false
|
268 |
+
batch_size: 1
|
269 |
+
- target: "female person" # what word for erasing the positive concept from
|
270 |
+
positive: "female person, cement sculpture, cement greek statue style" # concept to erase
|
271 |
+
unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
|
272 |
+
neutral: "female person" # starting point for conditioning the target
|
273 |
+
action: "enhance" # erase or enhance
|
274 |
+
guidance_scale: 4
|
275 |
+
resolution: 512
|
276 |
+
dynamic_resolution: false
|
277 |
+
batch_size: 1
|
278 |
+
####################################################################################################### METAL SLIDER
|
279 |
+
# - target: "" # what word for erasing the positive concept from
|
280 |
+
# positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase
|
281 |
+
# unconditional: "wooden style, made out of wood" # word to take the difference from the positive concept
|
282 |
+
# neutral: "" # starting point for conditioning the target
|
283 |
+
# action: "enhance" # erase or enhance
|
284 |
+
# guidance_scale: 4
|
285 |
+
# resolution: 512
|
286 |
+
# dynamic_resolution: false
|
287 |
+
# batch_size: 1
|
288 |
+
####################################################################################################### FESTIVE SLIDER
|
289 |
+
# - target: "" # what word for erasing the positive concept from
|
290 |
+
# positive: "festive, colorful banners, confetti, indian festival decorations, chinese festival decorations, fireworks, parade, cherry, gala, happy, celebrations" # concept to erase
|
291 |
+
# unconditional: "dull, dark, sad, desserted, empty, alone" # word to take the difference from the positive concept
|
292 |
+
# neutral: "" # starting point for conditioning the target
|
293 |
+
# action: "enhance" # erase or enhance
|
294 |
+
# guidance_scale: 4
|
295 |
+
# resolution: 512
|
296 |
+
# dynamic_resolution: false
|
297 |
+
# batch_size: 1
|
298 |
+
####################################################################################################### TROPICAL SLIDER
|
299 |
+
# - target: "" # what word for erasing the positive concept from
|
300 |
+
# positive: "tropical, beach, sunny, hot" # concept to erase
|
301 |
+
# unconditional: "arctic, winter, snow, ice, iceburg, snowfall" # word to take the difference from the positive concept
|
302 |
+
# neutral: "" # starting point for conditioning the target
|
303 |
+
# action: "enhance" # erase or enhance
|
304 |
+
# guidance_scale: 4
|
305 |
+
# resolution: 512
|
306 |
+
# dynamic_resolution: false
|
307 |
+
# batch_size: 1
|
308 |
+
####################################################################################################### MODERN SLIDER
|
309 |
+
# - target: "" # what word for erasing the positive concept from
|
310 |
+
# positive: "modern, futuristic style, trendy, stylish, swank" # concept to erase
|
311 |
+
# unconditional: "ancient, classic style, regal, vintage" # word to take the difference from the positive concept
|
312 |
+
# neutral: "" # starting point for conditioning the target
|
313 |
+
# action: "enhance" # erase or enhance
|
314 |
+
# guidance_scale: 4
|
315 |
+
# resolution: 512
|
316 |
+
# dynamic_resolution: false
|
317 |
+
# batch_size: 1
|
318 |
+
####################################################################################################### BOKEH SLIDER
|
319 |
+
# - target: "" # what word for erasing the positive concept from
|
320 |
+
# positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
|
321 |
+
# # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
|
322 |
+
# unconditional: ""
|
323 |
+
# neutral: "" # starting point for conditioning the target
|
324 |
+
# action: "enhance" # erase or enhance
|
325 |
+
# guidance_scale: 4
|
326 |
+
# resolution: 512
|
327 |
+
# dynamic_resolution: false
|
328 |
+
# batch_size: 1
|
329 |
+
####################################################################################################### LONG HAIR SLIDER
|
330 |
+
# - target: "male person" # what word for erasing the positive concept from
|
331 |
+
# positive: "male person, with long hair" # concept to erase
|
332 |
+
# unconditional: "male person, with short hair" # word to take the difference from the positive concept
|
333 |
+
# neutral: "male person" # starting point for conditioning the target
|
334 |
+
# action: "enhance" # erase or enhance
|
335 |
+
# guidance_scale: 4
|
336 |
+
# resolution: 512
|
337 |
+
# dynamic_resolution: false
|
338 |
+
# batch_size: 1
|
339 |
+
# - target: "female person" # what word for erasing the positive concept from
|
340 |
+
# positive: "female person, with long hair" # concept to erase
|
341 |
+
# unconditional: "female person, with short hair" # word to take the difference from the positive concept
|
342 |
+
# neutral: "female person" # starting point for conditioning the target
|
343 |
+
# action: "enhance" # erase or enhance
|
344 |
+
# guidance_scale: 4
|
345 |
+
# resolution: 512
|
346 |
+
# dynamic_resolution: false
|
347 |
+
# batch_size: 1
|
348 |
+
####################################################################################################### NEGPROMPT SLIDER
|
349 |
+
# - target: "" # what word for erasing the positive concept from
|
350 |
+
# positive: "cartoon, cgi, render, illustration, painting, drawing, bad quality, grainy, low resolution" # concept to erase
|
351 |
+
# unconditional: ""
|
352 |
+
# neutral: "" # starting point for conditioning the target
|
353 |
+
# action: "erase" # erase or enhance
|
354 |
+
# guidance_scale: 4
|
355 |
+
# resolution: 512
|
356 |
+
# dynamic_resolution: false
|
357 |
+
# batch_size: 1
|
358 |
+
####################################################################################################### EXPENSIVE FOOD SLIDER
|
359 |
+
# - target: "food" # what word for erasing the positive concept from
|
360 |
+
# positive: "food, expensive and fine dining" # concept to erase
|
361 |
+
# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
|
362 |
+
# neutral: "food" # starting point for conditioning the target
|
363 |
+
# action: "enhance" # erase or enhance
|
364 |
+
# guidance_scale: 4
|
365 |
+
# resolution: 512
|
366 |
+
# dynamic_resolution: false
|
367 |
+
# batch_size: 1
|
368 |
+
####################################################################################################### COOKED FOOD SLIDER
|
369 |
+
# - target: "food" # what word for erasing the positive concept from
|
370 |
+
# positive: "food, cooked, baked, roasted, fried" # concept to erase
|
371 |
+
# unconditional: "food, raw, uncooked, fresh, undone" # word to take the difference from the positive concept
|
372 |
+
# neutral: "food" # starting point for conditioning the target
|
373 |
+
# action: "enhance" # erase or enhance
|
374 |
+
# guidance_scale: 4
|
375 |
+
# resolution: 512
|
376 |
+
# dynamic_resolution: false
|
377 |
+
# batch_size: 1
|
378 |
+
####################################################################################################### MEAT FOOD SLIDER
|
379 |
+
# - target: "food" # what word for erasing the positive concept from
|
380 |
+
# positive: "food, meat, steak, fish, non-vegetrian, beef, lamb, pork, chicken, salmon" # concept to erase
|
381 |
+
# unconditional: "food, vegetables, fruits, leafy-vegetables, greens, vegetarian, vegan, tomatoes, onions, carrots" # word to take the difference from the positive concept
|
382 |
+
# neutral: "food" # starting point for conditioning the target
|
383 |
+
# action: "enhance" # erase or enhance
|
384 |
+
# guidance_scale: 4
|
385 |
+
# resolution: 512
|
386 |
+
# dynamic_resolution: false
|
387 |
+
# batch_size: 1
|
388 |
+
####################################################################################################### WEATHER SLIDER
|
389 |
+
# - target: "" # what word for erasing the positive concept from
|
390 |
+
# positive: "snowy, winter, cold, ice, snowfall, white" # concept to erase
|
391 |
+
# unconditional: "hot, summer, bright, sunny" # word to take the difference from the positive concept
|
392 |
+
# neutral: "" # starting point for conditioning the target
|
393 |
+
# action: "enhance" # erase or enhance
|
394 |
+
# guidance_scale: 4
|
395 |
+
# resolution: 512
|
396 |
+
# dynamic_resolution: false
|
397 |
+
# batch_size: 1
|
398 |
+
####################################################################################################### NIGHT/DAY SLIDER
|
399 |
+
# - target: "" # what word for erasing the positive concept from
|
400 |
+
# positive: "night time, dark, darkness, pitch black, nighttime" # concept to erase
|
401 |
+
# unconditional: "day time, bright, sunny, daytime, sunlight" # word to take the difference from the positive concept
|
402 |
+
# neutral: "" # starting point for conditioning the target
|
403 |
+
# action: "enhance" # erase or enhance
|
404 |
+
# guidance_scale: 4
|
405 |
+
# resolution: 512
|
406 |
+
# dynamic_resolution: false
|
407 |
+
# batch_size: 1
|
408 |
+
####################################################################################################### INDOOR/OUTDOOR SLIDER
|
409 |
+
# - target: "" # what word for erasing the positive concept from
|
410 |
+
# positive: "indoor, inside a room, inside, interior" # concept to erase
|
411 |
+
# unconditional: "outdoor, outside, open air, exterior" # word to take the difference from the positive concept
|
412 |
+
# neutral: "" # starting point for conditioning the target
|
413 |
+
# action: "enhance" # erase or enhance
|
414 |
+
# guidance_scale: 4
|
415 |
+
# resolution: 512
|
416 |
+
# dynamic_resolution: false
|
417 |
+
# batch_size: 1
|
418 |
+
####################################################################################################### GOODHANDS SLIDER
|
419 |
+
# - target: "" # what word for erasing the positive concept from
|
420 |
+
# positive: "realistic hands, realistic limbs, perfect limbs, perfect hands, 5 fingers, five fingers, hyper realisitc hands" # concept to erase
|
421 |
+
# unconditional: "poorly drawn limbs, distorted limbs, poorly rendered hands,bad anatomy, disfigured, mutated body parts, bad composition" # word to take the difference from the positive concept
|
422 |
+
# neutral: "" # starting point for conditioning the target
|
423 |
+
# action: "enhance" # erase or enhance
|
424 |
+
# guidance_scale: 4
|
425 |
+
# resolution: 512
|
426 |
+
# dynamic_resolution: false
|
427 |
+
# batch_size: 1
|
428 |
+
####################################################################################################### RUSTY CAR SLIDER
|
429 |
+
# - target: "car" # what word for erasing the positive concept from
|
430 |
+
# positive: "car, rusty conditioned" # concept to erase
|
431 |
+
# unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept
|
432 |
+
# neutral: "car" # starting point for conditioning the target
|
433 |
+
# action: "enhance" # erase or enhance
|
434 |
+
# guidance_scale: 4
|
435 |
+
# resolution: 512
|
436 |
+
# dynamic_resolution: false
|
437 |
+
# batch_size: 1
|
438 |
+
####################################################################################################### RUSTY CAR SLIDER
|
439 |
+
# - target: "car" # what word for erasing the positive concept from
|
440 |
+
# positive: "car, damaged, broken headlights, dented car, with scrapped paintwork" # concept to erase
|
441 |
+
# unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept
|
442 |
+
# neutral: "car" # starting point for conditioning the target
|
443 |
+
# action: "enhance" # erase or enhance
|
444 |
+
# guidance_scale: 4
|
445 |
+
# resolution: 512
|
446 |
+
# dynamic_resolution: false
|
447 |
+
# batch_size: 1
|
448 |
+
####################################################################################################### CLUTTERED ROOM SLIDER
|
449 |
+
# - target: "room" # what word for erasing the positive concept from
|
450 |
+
# positive: "room, cluttered, disorganized, dirty, jumbled, scattered" # concept to erase
|
451 |
+
# unconditional: "room, super organized, clean, ordered, neat, tidy" # word to take the difference from the positive concept
|
452 |
+
# neutral: "room" # starting point for conditioning the target
|
453 |
+
# action: "enhance" # erase or enhance
|
454 |
+
# guidance_scale: 4
|
455 |
+
# resolution: 512
|
456 |
+
# dynamic_resolution: false
|
457 |
+
# batch_size: 1
|
458 |
+
####################################################################################################### HANDS SLIDER
|
459 |
+
# - target: "hands" # what word for erasing the positive concept from
|
460 |
+
# positive: "realistic hands, five fingers, 8k hyper realistic hands" # concept to erase
|
461 |
+
# unconditional: "poorly drawn hands, distorted hands, amputed fingers" # word to take the difference from the positive concept
|
462 |
+
# neutral: "hands" # starting point for conditioning the target
|
463 |
+
# action: "enhance" # erase or enhance
|
464 |
+
# guidance_scale: 4
|
465 |
+
# resolution: 512
|
466 |
+
# dynamic_resolution: false
|
467 |
+
# batch_size: 1
|
468 |
+
####################################################################################################### HANDS SLIDER
|
469 |
+
# - target: "female person" # what word for erasing the positive concept from
|
470 |
+
# positive: "female person, with a surprised look" # concept to erase
|
471 |
+
# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
|
472 |
+
# neutral: "female person" # starting point for conditioning the target
|
473 |
+
# action: "enhance" # erase or enhance
|
474 |
+
# guidance_scale: 4
|
475 |
+
# resolution: 512
|
476 |
+
# dynamic_resolution: false
|
477 |
+
# batch_size: 1
|
trainscripts/textsliders/data/prompts.yaml
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
- target: "male person" # what word for erasing the positive concept from
|
2 |
+
positive: "male person, very old" # concept to erase
|
3 |
+
unconditional: "male person, very young" # word to take the difference from the positive concept
|
4 |
+
neutral: "male person" # starting point for conditioning the target
|
5 |
+
action: "enhance" # erase or enhance
|
6 |
+
guidance_scale: 4
|
7 |
+
resolution: 512
|
8 |
+
dynamic_resolution: false
|
9 |
+
batch_size: 1
|
10 |
+
- target: "female person" # what word for erasing the positive concept from
|
11 |
+
positive: "female person, very old" # concept to erase
|
12 |
+
unconditional: "female person, very young" # word to take the difference from the positive concept
|
13 |
+
neutral: "female person" # starting point for conditioning the target
|
14 |
+
action: "enhance" # erase or enhance
|
15 |
+
guidance_scale: 4
|
16 |
+
resolution: 512
|
17 |
+
dynamic_resolution: false
|
18 |
+
batch_size: 1
|
19 |
+
# - target: "" # what word for erasing the positive concept from
|
20 |
+
# positive: "a group of people" # concept to erase
|
21 |
+
# unconditional: "a person" # word to take the difference from the positive concept
|
22 |
+
# neutral: "" # starting point for conditioning the target
|
23 |
+
# action: "enhance" # erase or enhance
|
24 |
+
# guidance_scale: 4
|
25 |
+
# resolution: 512
|
26 |
+
# dynamic_resolution: false
|
27 |
+
# batch_size: 1
|
28 |
+
# - target: "" # what word for erasing the positive concept from
|
29 |
+
# positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
|
30 |
+
# unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
|
31 |
+
# neutral: "" # starting point for conditioning the target
|
32 |
+
# action: "enhance" # erase or enhance
|
33 |
+
# guidance_scale: 4
|
34 |
+
# resolution: 512
|
35 |
+
# dynamic_resolution: false
|
36 |
+
# batch_size: 1
|
37 |
+
# - target: "" # what word for erasing the positive concept from
|
38 |
+
# positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
|
39 |
+
# # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
|
40 |
+
# unconditional: ""
|
41 |
+
# neutral: "" # starting point for conditioning the target
|
42 |
+
# action: "enhance" # erase or enhance
|
43 |
+
# guidance_scale: 4
|
44 |
+
# resolution: 512
|
45 |
+
# dynamic_resolution: false
|
46 |
+
# batch_size: 1
|
47 |
+
# - target: "food" # what word for erasing the positive concept from
|
48 |
+
# positive: "food, expensive and fine dining" # concept to erase
|
49 |
+
# unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
|
50 |
+
# neutral: "food" # starting point for conditioning the target
|
51 |
+
# action: "enhance" # erase or enhance
|
52 |
+
# guidance_scale: 4
|
53 |
+
# resolution: 512
|
54 |
+
# dynamic_resolution: false
|
55 |
+
# batch_size: 1
|
56 |
+
# - target: "room" # what word for erasing the positive concept from
|
57 |
+
# positive: "room, dirty disorganised and cluttered" # concept to erase
|
58 |
+
# unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
|
59 |
+
# neutral: "room" # starting point for conditioning the target
|
60 |
+
# action: "enhance" # erase or enhance
|
61 |
+
# guidance_scale: 4
|
62 |
+
# resolution: 512
|
63 |
+
# dynamic_resolution: false
|
64 |
+
# batch_size: 1
|
65 |
+
# - target: "male person" # what word for erasing the positive concept from
|
66 |
+
# positive: "male person, with a surprised look" # concept to erase
|
67 |
+
# unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
|
68 |
+
# neutral: "male person" # starting point for conditioning the target
|
69 |
+
# action: "enhance" # erase or enhance
|
70 |
+
# guidance_scale: 4
|
71 |
+
# resolution: 512
|
72 |
+
# dynamic_resolution: false
|
73 |
+
# batch_size: 1
|
74 |
+
# - target: "female person" # what word for erasing the positive concept from
|
75 |
+
# positive: "female person, with a surprised look" # concept to erase
|
76 |
+
# unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
|
77 |
+
# neutral: "female person" # starting point for conditioning the target
|
78 |
+
# action: "enhance" # erase or enhance
|
79 |
+
# guidance_scale: 4
|
80 |
+
# resolution: 512
|
81 |
+
# dynamic_resolution: false
|
82 |
+
# batch_size: 1
|
83 |
+
# - target: "sky" # what word for erasing the positive concept from
|
84 |
+
# positive: "peaceful sky" # concept to erase
|
85 |
+
# unconditional: "sky" # word to take the difference from the positive concept
|
86 |
+
# neutral: "sky" # starting point for conditioning the target
|
87 |
+
# action: "enhance" # erase or enhance
|
88 |
+
# guidance_scale: 4
|
89 |
+
# resolution: 512
|
90 |
+
# dynamic_resolution: false
|
91 |
+
# batch_size: 1
|
92 |
+
# - target: "sky" # what word for erasing the positive concept from
|
93 |
+
# positive: "chaotic dark sky" # concept to erase
|
94 |
+
# unconditional: "sky" # word to take the difference from the positive concept
|
95 |
+
# neutral: "sky" # starting point for conditioning the target
|
96 |
+
# action: "erase" # erase or enhance
|
97 |
+
# guidance_scale: 4
|
98 |
+
# resolution: 512
|
99 |
+
# dynamic_resolution: false
|
100 |
+
# batch_size: 1
|
101 |
+
# - target: "person" # what word for erasing the positive concept from
|
102 |
+
# positive: "person, very young" # concept to erase
|
103 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
104 |
+
# neutral: "person" # starting point for conditioning the target
|
105 |
+
# action: "erase" # erase or enhance
|
106 |
+
# guidance_scale: 4
|
107 |
+
# resolution: 512
|
108 |
+
# dynamic_resolution: false
|
109 |
+
# batch_size: 1
|
110 |
+
# overweight
|
111 |
+
# - target: "art" # what word for erasing the positive concept from
|
112 |
+
# positive: "realistic art" # concept to erase
|
113 |
+
# unconditional: "art" # word to take the difference from the positive concept
|
114 |
+
# neutral: "art" # starting point for conditioning the target
|
115 |
+
# action: "enhance" # erase or enhance
|
116 |
+
# guidance_scale: 4
|
117 |
+
# resolution: 512
|
118 |
+
# dynamic_resolution: false
|
119 |
+
# batch_size: 1
|
120 |
+
# - target: "art" # what word for erasing the positive concept from
|
121 |
+
# positive: "abstract art" # concept to erase
|
122 |
+
# unconditional: "art" # word to take the difference from the positive concept
|
123 |
+
# neutral: "art" # starting point for conditioning the target
|
124 |
+
# action: "erase" # erase or enhance
|
125 |
+
# guidance_scale: 4
|
126 |
+
# resolution: 512
|
127 |
+
# dynamic_resolution: false
|
128 |
+
# batch_size: 1
|
129 |
+
# sky
|
130 |
+
# - target: "weather" # what word for erasing the positive concept from
|
131 |
+
# positive: "bright pleasant weather" # concept to erase
|
132 |
+
# unconditional: "weather" # word to take the difference from the positive concept
|
133 |
+
# neutral: "weather" # starting point for conditioning the target
|
134 |
+
# action: "enhance" # erase or enhance
|
135 |
+
# guidance_scale: 4
|
136 |
+
# resolution: 512
|
137 |
+
# dynamic_resolution: false
|
138 |
+
# batch_size: 1
|
139 |
+
# - target: "weather" # what word for erasing the positive concept from
|
140 |
+
# positive: "dark gloomy weather" # concept to erase
|
141 |
+
# unconditional: "weather" # word to take the difference from the positive concept
|
142 |
+
# neutral: "weather" # starting point for conditioning the target
|
143 |
+
# action: "erase" # erase or enhance
|
144 |
+
# guidance_scale: 4
|
145 |
+
# resolution: 512
|
146 |
+
# dynamic_resolution: false
|
147 |
+
# batch_size: 1
|
148 |
+
# hair
|
149 |
+
# - target: "person" # what word for erasing the positive concept from
|
150 |
+
# positive: "person with long hair" # concept to erase
|
151 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
152 |
+
# neutral: "person" # starting point for conditioning the target
|
153 |
+
# action: "enhance" # erase or enhance
|
154 |
+
# guidance_scale: 4
|
155 |
+
# resolution: 512
|
156 |
+
# dynamic_resolution: false
|
157 |
+
# batch_size: 1
|
158 |
+
# - target: "person" # what word for erasing the positive concept from
|
159 |
+
# positive: "person with short hair" # concept to erase
|
160 |
+
# unconditional: "person" # word to take the difference from the positive concept
|
161 |
+
# neutral: "person" # starting point for conditioning the target
|
162 |
+
# action: "erase" # erase or enhance
|
163 |
+
# guidance_scale: 4
|
164 |
+
# resolution: 512
|
165 |
+
# dynamic_resolution: false
|
166 |
+
# batch_size: 1
|
167 |
+
# - target: "girl" # what word for erasing the positive concept from
|
168 |
+
# positive: "baby girl" # concept to erase
|
169 |
+
# unconditional: "girl" # word to take the difference from the positive concept
|
170 |
+
# neutral: "girl" # starting point for conditioning the target
|
171 |
+
# action: "enhance" # erase or enhance
|
172 |
+
# guidance_scale: -4
|
173 |
+
# resolution: 512
|
174 |
+
# dynamic_resolution: false
|
175 |
+
# batch_size: 1
|
176 |
+
# - target: "boy" # what word for erasing the positive concept from
|
177 |
+
# positive: "old man" # concept to erase
|
178 |
+
# unconditional: "boy" # word to take the difference from the positive concept
|
179 |
+
# neutral: "boy" # starting point for conditioning the target
|
180 |
+
# action: "enhance" # erase or enhance
|
181 |
+
# guidance_scale: 4
|
182 |
+
# resolution: 512
|
183 |
+
# dynamic_resolution: false
|
184 |
+
# batch_size: 1
|
185 |
+
# - target: "boy" # what word for erasing the positive concept from
|
186 |
+
# positive: "baby boy" # concept to erase
|
187 |
+
# unconditional: "boy" # word to take the difference from the positive concept
|
188 |
+
# neutral: "boy" # starting point for conditioning the target
|
189 |
+
# action: "enhance" # erase or enhance
|
190 |
+
# guidance_scale: -4
|
191 |
+
# resolution: 512
|
192 |
+
# dynamic_resolution: false
|
193 |
+
# batch_size: 1
|
trainscripts/textsliders/debug_util.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ใใใใฐ็จ...
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def check_requires_grad(model: torch.nn.Module):
|
7 |
+
for name, module in list(model.named_modules())[:5]:
|
8 |
+
if len(list(module.parameters())) > 0:
|
9 |
+
print(f"Module: {name}")
|
10 |
+
for name, param in list(module.named_parameters())[:2]:
|
11 |
+
print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
|
12 |
+
|
13 |
+
|
14 |
+
def check_training_mode(model: torch.nn.Module):
|
15 |
+
for name, module in list(model.named_modules())[:5]:
|
16 |
+
print(f"Module: {name}, Training Mode: {module.training}")
|
trainscripts/textsliders/flush.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gc
|
3 |
+
|
4 |
+
torch.cuda.empty_cache()
|
5 |
+
gc.collect()
|
trainscripts/textsliders/generate_images_xl.py
ADDED
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import argparse
|
4 |
+
import os, json, random
|
5 |
+
import pandas as pd
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import glob, re
|
8 |
+
|
9 |
+
from safetensors.torch import load_file
|
10 |
+
import matplotlib.image as mpimg
|
11 |
+
import copy
|
12 |
+
import gc
|
13 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
14 |
+
|
15 |
+
import diffusers
|
16 |
+
from diffusers import DiffusionPipeline
|
17 |
+
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
|
18 |
+
from diffusers.loaders import AttnProcsLayers
|
19 |
+
from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
|
20 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
21 |
+
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
|
22 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
23 |
+
import inspect
|
24 |
+
import os
|
25 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
26 |
+
from diffusers.pipelines import StableDiffusionXLPipeline
|
27 |
+
import random
|
28 |
+
|
29 |
+
import torch
|
30 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
31 |
+
import re
|
32 |
+
import argparse
|
33 |
+
|
34 |
+
def flush():
|
35 |
+
torch.cuda.empty_cache()
|
36 |
+
gc.collect()
|
37 |
+
|
38 |
+
@torch.no_grad()
|
39 |
+
def call(
|
40 |
+
self,
|
41 |
+
prompt: Union[str, List[str]] = None,
|
42 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
43 |
+
height: Optional[int] = None,
|
44 |
+
width: Optional[int] = None,
|
45 |
+
num_inference_steps: int = 50,
|
46 |
+
denoising_end: Optional[float] = None,
|
47 |
+
guidance_scale: float = 5.0,
|
48 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
49 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
50 |
+
num_images_per_prompt: Optional[int] = 1,
|
51 |
+
eta: float = 0.0,
|
52 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
53 |
+
latents: Optional[torch.FloatTensor] = None,
|
54 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
55 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
56 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
57 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
58 |
+
output_type: Optional[str] = "pil",
|
59 |
+
return_dict: bool = True,
|
60 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
61 |
+
callback_steps: int = 1,
|
62 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
63 |
+
guidance_rescale: float = 0.0,
|
64 |
+
original_size: Optional[Tuple[int, int]] = None,
|
65 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
66 |
+
target_size: Optional[Tuple[int, int]] = None,
|
67 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
68 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
69 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
70 |
+
|
71 |
+
network=None,
|
72 |
+
start_noise=None,
|
73 |
+
scale=None,
|
74 |
+
unet=None,
|
75 |
+
):
|
76 |
+
r"""
|
77 |
+
Function invoked when calling the pipeline for generation.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
prompt (`str` or `List[str]`, *optional*):
|
81 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
82 |
+
instead.
|
83 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
84 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
85 |
+
used in both text-encoders
|
86 |
+
height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
|
87 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
88 |
+
Anything below 512 pixels won't work well for
|
89 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
90 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
91 |
+
width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
|
92 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
93 |
+
Anything below 512 pixels won't work well for
|
94 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
95 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
96 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
97 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
98 |
+
expense of slower inference.
|
99 |
+
denoising_end (`float`, *optional*):
|
100 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
101 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
102 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
103 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
104 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
105 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
106 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
107 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
108 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
109 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
110 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
111 |
+
usually at the expense of lower image quality.
|
112 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
113 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
114 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
115 |
+
less than `1`).
|
116 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
117 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
118 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
119 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
120 |
+
The number of images to generate per prompt.
|
121 |
+
eta (`float`, *optional*, defaults to 0.0):
|
122 |
+
Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
123 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
124 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
125 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
126 |
+
to make generation deterministic.
|
127 |
+
latents (`torch.FloatTensor`, *optional*):
|
128 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
129 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
130 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
131 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
132 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
133 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
134 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
135 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
136 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
137 |
+
argument.
|
138 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
139 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
140 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
141 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
142 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
143 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
144 |
+
input argument.
|
145 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
146 |
+
The output format of the generate image. Choose between
|
147 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
148 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
149 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
150 |
+
of a plain tuple.
|
151 |
+
callback (`Callable`, *optional*):
|
152 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
153 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
154 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
155 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
156 |
+
called at every step.
|
157 |
+
cross_attention_kwargs (`dict`, *optional*):
|
158 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
159 |
+
`self.processor` in
|
160 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
161 |
+
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
162 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
163 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `ฯ` in equation 16. of
|
164 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
165 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
166 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
167 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
168 |
+
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
|
169 |
+
explained in section 2.2 of
|
170 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
171 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
172 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
173 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
174 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
175 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
176 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
177 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
178 |
+
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
179 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
180 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
181 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
182 |
+
micro-conditioning as explained in section 2.2 of
|
183 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
184 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
185 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
186 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
187 |
+
micro-conditioning as explained in section 2.2 of
|
188 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
189 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
190 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
191 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
192 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
193 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
194 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
195 |
+
|
196 |
+
Examples:
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
200 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
201 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
202 |
+
"""
|
203 |
+
# 0. Default height and width to unet
|
204 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
205 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
206 |
+
|
207 |
+
original_size = original_size or (height, width)
|
208 |
+
target_size = target_size or (height, width)
|
209 |
+
|
210 |
+
# 1. Check inputs. Raise error if not correct
|
211 |
+
self.check_inputs(
|
212 |
+
prompt,
|
213 |
+
prompt_2,
|
214 |
+
height,
|
215 |
+
width,
|
216 |
+
callback_steps,
|
217 |
+
negative_prompt,
|
218 |
+
negative_prompt_2,
|
219 |
+
prompt_embeds,
|
220 |
+
negative_prompt_embeds,
|
221 |
+
pooled_prompt_embeds,
|
222 |
+
negative_pooled_prompt_embeds,
|
223 |
+
)
|
224 |
+
|
225 |
+
# 2. Define call parameters
|
226 |
+
if prompt is not None and isinstance(prompt, str):
|
227 |
+
batch_size = 1
|
228 |
+
elif prompt is not None and isinstance(prompt, list):
|
229 |
+
batch_size = len(prompt)
|
230 |
+
else:
|
231 |
+
batch_size = prompt_embeds.shape[0]
|
232 |
+
|
233 |
+
device = self._execution_device
|
234 |
+
|
235 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
236 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
237 |
+
# corresponds to doing no classifier free guidance.
|
238 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
239 |
+
|
240 |
+
# 3. Encode input prompt
|
241 |
+
text_encoder_lora_scale = (
|
242 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
243 |
+
)
|
244 |
+
(
|
245 |
+
prompt_embeds,
|
246 |
+
negative_prompt_embeds,
|
247 |
+
pooled_prompt_embeds,
|
248 |
+
negative_pooled_prompt_embeds,
|
249 |
+
) = self.encode_prompt(
|
250 |
+
prompt=prompt,
|
251 |
+
prompt_2=prompt_2,
|
252 |
+
device=device,
|
253 |
+
num_images_per_prompt=num_images_per_prompt,
|
254 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
255 |
+
negative_prompt=negative_prompt,
|
256 |
+
negative_prompt_2=negative_prompt_2,
|
257 |
+
prompt_embeds=prompt_embeds,
|
258 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
259 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
260 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
261 |
+
lora_scale=text_encoder_lora_scale,
|
262 |
+
)
|
263 |
+
|
264 |
+
# 4. Prepare timesteps
|
265 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
266 |
+
|
267 |
+
timesteps = self.scheduler.timesteps
|
268 |
+
|
269 |
+
# 5. Prepare latent variables
|
270 |
+
num_channels_latents = unet.config.in_channels
|
271 |
+
latents = self.prepare_latents(
|
272 |
+
batch_size * num_images_per_prompt,
|
273 |
+
num_channels_latents,
|
274 |
+
height,
|
275 |
+
width,
|
276 |
+
prompt_embeds.dtype,
|
277 |
+
device,
|
278 |
+
generator,
|
279 |
+
latents,
|
280 |
+
)
|
281 |
+
|
282 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
283 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
284 |
+
|
285 |
+
# 7. Prepare added time ids & embeddings
|
286 |
+
add_text_embeds = pooled_prompt_embeds
|
287 |
+
add_time_ids = self._get_add_time_ids(
|
288 |
+
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
289 |
+
)
|
290 |
+
if negative_original_size is not None and negative_target_size is not None:
|
291 |
+
negative_add_time_ids = self._get_add_time_ids(
|
292 |
+
negative_original_size,
|
293 |
+
negative_crops_coords_top_left,
|
294 |
+
negative_target_size,
|
295 |
+
dtype=prompt_embeds.dtype,
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
negative_add_time_ids = add_time_ids
|
299 |
+
|
300 |
+
if do_classifier_free_guidance:
|
301 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
302 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
303 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
304 |
+
|
305 |
+
prompt_embeds = prompt_embeds.to(device)
|
306 |
+
add_text_embeds = add_text_embeds.to(device)
|
307 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
308 |
+
|
309 |
+
# 8. Denoising loop
|
310 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
311 |
+
|
312 |
+
# 7.1 Apply denoising_end
|
313 |
+
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
314 |
+
discrete_timestep_cutoff = int(
|
315 |
+
round(
|
316 |
+
self.scheduler.config.num_train_timesteps
|
317 |
+
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
318 |
+
)
|
319 |
+
)
|
320 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
321 |
+
timesteps = timesteps[:num_inference_steps]
|
322 |
+
|
323 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
324 |
+
for i, t in enumerate(timesteps):
|
325 |
+
if t>start_noise:
|
326 |
+
network.set_lora_slider(scale=0)
|
327 |
+
else:
|
328 |
+
network.set_lora_slider(scale=scale)
|
329 |
+
# expand the latents if we are doing classifier free guidance
|
330 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
331 |
+
|
332 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
333 |
+
|
334 |
+
# predict the noise residual
|
335 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
336 |
+
with network:
|
337 |
+
noise_pred = unet(
|
338 |
+
latent_model_input,
|
339 |
+
t,
|
340 |
+
encoder_hidden_states=prompt_embeds,
|
341 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
342 |
+
added_cond_kwargs=added_cond_kwargs,
|
343 |
+
return_dict=False,
|
344 |
+
)[0]
|
345 |
+
|
346 |
+
# perform guidance
|
347 |
+
if do_classifier_free_guidance:
|
348 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
349 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
350 |
+
|
351 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
352 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
353 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
354 |
+
|
355 |
+
# compute the previous noisy sample x_t -> x_t-1
|
356 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
357 |
+
|
358 |
+
# call the callback, if provided
|
359 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
360 |
+
progress_bar.update()
|
361 |
+
if callback is not None and i % callback_steps == 0:
|
362 |
+
callback(i, t, latents)
|
363 |
+
|
364 |
+
if not output_type == "latent":
|
365 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
366 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
367 |
+
|
368 |
+
if needs_upcasting:
|
369 |
+
self.upcast_vae()
|
370 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
371 |
+
|
372 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
373 |
+
|
374 |
+
# cast back to fp16 if needed
|
375 |
+
if needs_upcasting:
|
376 |
+
self.vae.to(dtype=torch.float16)
|
377 |
+
else:
|
378 |
+
image = latents
|
379 |
+
|
380 |
+
if not output_type == "latent":
|
381 |
+
# apply watermark if available
|
382 |
+
if self.watermark is not None:
|
383 |
+
image = self.watermark.apply_watermark(image)
|
384 |
+
|
385 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
386 |
+
|
387 |
+
# Offload all models
|
388 |
+
# self.maybe_free_model_hooks()
|
389 |
+
|
390 |
+
if not return_dict:
|
391 |
+
return (image,)
|
392 |
+
|
393 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
394 |
+
|
395 |
+
|
396 |
+
def sorted_nicely( l ):
|
397 |
+
convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text
|
398 |
+
alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ]
|
399 |
+
return sorted(l, key = alphanum_key)
|
400 |
+
|
401 |
+
def flush():
|
402 |
+
torch.cuda.empty_cache()
|
403 |
+
gc.collect()
|
404 |
+
|
405 |
+
|
406 |
+
if __name__=='__main__':
|
407 |
+
|
408 |
+
device = 'cuda:0'
|
409 |
+
StableDiffusionXLPipeline.__call__ = call
|
410 |
+
pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
|
411 |
+
|
412 |
+
# pipe.__call__ = call
|
413 |
+
pipe = pipe.to(device)
|
414 |
+
|
415 |
+
|
416 |
+
parser = argparse.ArgumentParser(
|
417 |
+
prog = 'generateImages',
|
418 |
+
description = 'Generate Images using Diffusers Code')
|
419 |
+
parser.add_argument('--model_name', help='name of model', type=str, required=True)
|
420 |
+
parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
|
421 |
+
parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None)
|
422 |
+
parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
|
423 |
+
parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4')
|
424 |
+
parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
|
425 |
+
parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
|
426 |
+
parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000)
|
427 |
+
parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
|
428 |
+
parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=5)
|
429 |
+
parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
|
430 |
+
parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4)
|
431 |
+
parser.add_argument('--start_noise', help='what time stamp to flip to edited model', type=int, required=False, default=750)
|
432 |
+
|
433 |
+
args = parser.parse_args()
|
434 |
+
lora_weight = args.model_name
|
435 |
+
csv_path = args.prompts_path
|
436 |
+
save_path = args.save_path
|
437 |
+
start_noise = args.start_noise
|
438 |
+
from_case = args.from_case
|
439 |
+
till_case = args.till_case
|
440 |
+
|
441 |
+
weight_dtype = torch.float16
|
442 |
+
num_images_per_prompt = 1
|
443 |
+
scales = [-2, -1, 0, 1, 2]
|
444 |
+
scales = [-1, -.5, 0, .5, 1]
|
445 |
+
scales = [-2]
|
446 |
+
df = pd.read_csv(csv_path)
|
447 |
+
|
448 |
+
for scale in scales:
|
449 |
+
os.makedirs(f'{save_path}/{os.path.basename(lora_weight)}/{scale}', exist_ok=True)
|
450 |
+
|
451 |
+
prompts = list(df['prompt'])
|
452 |
+
seeds = list(df['evaluation_seed'])
|
453 |
+
case_numbers = list(df['case_number'])
|
454 |
+
pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',torch_dtype=torch.float16,)
|
455 |
+
|
456 |
+
# pipe.__call__ = call
|
457 |
+
pipe = pipe.to(device)
|
458 |
+
unet = pipe.unet
|
459 |
+
if 'full' in lora_weight:
|
460 |
+
train_method = 'full'
|
461 |
+
elif 'noxattn' in lora_weight:
|
462 |
+
train_method = 'noxattn'
|
463 |
+
else:
|
464 |
+
train_method = 'noxattn'
|
465 |
+
|
466 |
+
network_type = "c3lier"
|
467 |
+
if train_method == 'xattn':
|
468 |
+
network_type = 'lierla'
|
469 |
+
|
470 |
+
modules = DEFAULT_TARGET_REPLACE
|
471 |
+
if network_type == "c3lier":
|
472 |
+
modules += UNET_TARGET_REPLACE_MODULE_CONV
|
473 |
+
import os
|
474 |
+
model_name = lora_weight
|
475 |
+
|
476 |
+
name = os.path.basename(model_name)
|
477 |
+
rank = 1
|
478 |
+
alpha = 4
|
479 |
+
if 'rank4' in lora_weight:
|
480 |
+
rank = 4
|
481 |
+
if 'rank8' in lora_weight:
|
482 |
+
rank = 8
|
483 |
+
if 'alpha1' in lora_weight:
|
484 |
+
alpha = 1.0
|
485 |
+
network = LoRANetwork(
|
486 |
+
unet,
|
487 |
+
rank=rank,
|
488 |
+
multiplier=1.0,
|
489 |
+
alpha=alpha,
|
490 |
+
train_method=train_method,
|
491 |
+
).to(device, dtype=weight_dtype)
|
492 |
+
network.load_state_dict(torch.load(lora_weight))
|
493 |
+
|
494 |
+
for idx, prompt in enumerate(prompts):
|
495 |
+
seed = seeds[idx]
|
496 |
+
case_number = case_numbers[idx]
|
497 |
+
|
498 |
+
if not (case_number>=from_case and case_number<=till_case):
|
499 |
+
continue
|
500 |
+
if os.path.exists(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png'):
|
501 |
+
continue
|
502 |
+
print(prompt, seed)
|
503 |
+
for scale in scales:
|
504 |
+
generator = torch.manual_seed(seed)
|
505 |
+
images = pipe(prompt, num_images_per_prompt=args.num_samples, num_inference_steps=50, generator=generator, network=network, start_noise=start_noise, scale=scale, unet=unet).images
|
506 |
+
for idx, im in enumerate(images):
|
507 |
+
im.save(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png')
|
508 |
+
del unet, network, pipe
|
509 |
+
unet = None
|
510 |
+
network = None
|
511 |
+
pipe = None
|
512 |
+
torch.cuda.empty_cache()
|
513 |
+
flush()
|
trainscripts/textsliders/lora.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ref:
|
2 |
+
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
3 |
+
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
|
4 |
+
|
5 |
+
import os
|
6 |
+
import math
|
7 |
+
from typing import Optional, List, Type, Set, Literal
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from diffusers import UNet2DConditionModel
|
12 |
+
from safetensors.torch import save_file
|
13 |
+
|
14 |
+
|
15 |
+
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
|
16 |
+
# "Transformer2DModel", # ใฉใใใใใฃใกใฎๆนใใใ๏ผ # attn1, 2
|
17 |
+
"Attention"
|
18 |
+
]
|
19 |
+
UNET_TARGET_REPLACE_MODULE_CONV = [
|
20 |
+
"ResnetBlock2D",
|
21 |
+
"Downsample2D",
|
22 |
+
"Upsample2D",
|
23 |
+
"DownBlock2D",
|
24 |
+
"UpBlock2D",
|
25 |
+
|
26 |
+
] # locon, 3clier
|
27 |
+
|
28 |
+
LORA_PREFIX_UNET = "lora_unet"
|
29 |
+
|
30 |
+
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
|
31 |
+
|
32 |
+
TRAINING_METHODS = Literal[
|
33 |
+
"noxattn", # train all layers except x-attns and time_embed layers
|
34 |
+
"innoxattn", # train all layers except self attention layers
|
35 |
+
"selfattn", # ESD-u, train only self attention layers
|
36 |
+
"xattn", # ESD-x, train only x attention layers
|
37 |
+
"full", # train all layers
|
38 |
+
"xattn-strict", # q and k values
|
39 |
+
"noxattn-hspace",
|
40 |
+
"noxattn-hspace-last",
|
41 |
+
# "xlayer",
|
42 |
+
# "outxattn",
|
43 |
+
# "outsattn",
|
44 |
+
# "inxattn",
|
45 |
+
# "inmidsattn",
|
46 |
+
# "selflayer",
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
class LoRAModule(nn.Module):
|
51 |
+
"""
|
52 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
lora_name,
|
58 |
+
org_module: nn.Module,
|
59 |
+
multiplier=1.0,
|
60 |
+
lora_dim=4,
|
61 |
+
alpha=1,
|
62 |
+
):
|
63 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
64 |
+
super().__init__()
|
65 |
+
self.lora_name = lora_name
|
66 |
+
self.lora_dim = lora_dim
|
67 |
+
|
68 |
+
if "Linear" in org_module.__class__.__name__:
|
69 |
+
in_dim = org_module.in_features
|
70 |
+
out_dim = org_module.out_features
|
71 |
+
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
72 |
+
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
73 |
+
|
74 |
+
elif "Conv" in org_module.__class__.__name__: # ไธๅฟ
|
75 |
+
in_dim = org_module.in_channels
|
76 |
+
out_dim = org_module.out_channels
|
77 |
+
|
78 |
+
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
|
79 |
+
if self.lora_dim != lora_dim:
|
80 |
+
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
81 |
+
|
82 |
+
kernel_size = org_module.kernel_size
|
83 |
+
stride = org_module.stride
|
84 |
+
padding = org_module.padding
|
85 |
+
self.lora_down = nn.Conv2d(
|
86 |
+
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
|
87 |
+
)
|
88 |
+
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
89 |
+
|
90 |
+
if type(alpha) == torch.Tensor:
|
91 |
+
alpha = alpha.detach().numpy()
|
92 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
93 |
+
self.scale = alpha / self.lora_dim
|
94 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # ๅฎๆฐใจใใฆๆฑใใ
|
95 |
+
|
96 |
+
# same as microsoft's
|
97 |
+
nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
|
98 |
+
nn.init.zeros_(self.lora_up.weight)
|
99 |
+
|
100 |
+
self.multiplier = multiplier
|
101 |
+
self.org_module = org_module # remove in applying
|
102 |
+
|
103 |
+
def apply_to(self):
|
104 |
+
self.org_forward = self.org_module.forward
|
105 |
+
self.org_module.forward = self.forward
|
106 |
+
del self.org_module
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
return (
|
110 |
+
self.org_forward(x)
|
111 |
+
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
class LoRANetwork(nn.Module):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
unet: UNet2DConditionModel,
|
119 |
+
rank: int = 4,
|
120 |
+
multiplier: float = 1.0,
|
121 |
+
alpha: float = 1.0,
|
122 |
+
train_method: TRAINING_METHODS = "full",
|
123 |
+
) -> None:
|
124 |
+
super().__init__()
|
125 |
+
self.lora_scale = 1
|
126 |
+
self.multiplier = multiplier
|
127 |
+
self.lora_dim = rank
|
128 |
+
self.alpha = alpha
|
129 |
+
|
130 |
+
# LoRAใฎใฟ
|
131 |
+
self.module = LoRAModule
|
132 |
+
|
133 |
+
# unetใฎloraใไฝใ
|
134 |
+
self.unet_loras = self.create_modules(
|
135 |
+
LORA_PREFIX_UNET,
|
136 |
+
unet,
|
137 |
+
DEFAULT_TARGET_REPLACE,
|
138 |
+
self.lora_dim,
|
139 |
+
self.multiplier,
|
140 |
+
train_method=train_method,
|
141 |
+
)
|
142 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
143 |
+
|
144 |
+
# assertion ๅๅใฎ่ขซใใใชใใ็ขบ่ชใใฆใใใใใ
|
145 |
+
lora_names = set()
|
146 |
+
for lora in self.unet_loras:
|
147 |
+
assert (
|
148 |
+
lora.lora_name not in lora_names
|
149 |
+
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
|
150 |
+
lora_names.add(lora.lora_name)
|
151 |
+
|
152 |
+
# ้ฉ็จใใ
|
153 |
+
for lora in self.unet_loras:
|
154 |
+
lora.apply_to()
|
155 |
+
self.add_module(
|
156 |
+
lora.lora_name,
|
157 |
+
lora,
|
158 |
+
)
|
159 |
+
|
160 |
+
del unet
|
161 |
+
|
162 |
+
torch.cuda.empty_cache()
|
163 |
+
|
164 |
+
def create_modules(
|
165 |
+
self,
|
166 |
+
prefix: str,
|
167 |
+
root_module: nn.Module,
|
168 |
+
target_replace_modules: List[str],
|
169 |
+
rank: int,
|
170 |
+
multiplier: float,
|
171 |
+
train_method: TRAINING_METHODS,
|
172 |
+
) -> list:
|
173 |
+
loras = []
|
174 |
+
names = []
|
175 |
+
for name, module in root_module.named_modules():
|
176 |
+
if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention ใจ Time Embed ไปฅๅคๅญฆ็ฟ
|
177 |
+
if "attn2" in name or "time_embed" in name:
|
178 |
+
continue
|
179 |
+
elif train_method == "innoxattn": # Cross Attention ไปฅๅคๅญฆ็ฟ
|
180 |
+
if "attn2" in name:
|
181 |
+
continue
|
182 |
+
elif train_method == "selfattn": # Self Attention ใฎใฟๅญฆ็ฟ
|
183 |
+
if "attn1" not in name:
|
184 |
+
continue
|
185 |
+
elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention ใฎใฟๅญฆ็ฟ
|
186 |
+
if "attn2" not in name:
|
187 |
+
continue
|
188 |
+
elif train_method == "full": # ๅ
จ้จๅญฆ็ฟ
|
189 |
+
pass
|
190 |
+
else:
|
191 |
+
raise NotImplementedError(
|
192 |
+
f"train_method: {train_method} is not implemented."
|
193 |
+
)
|
194 |
+
if module.__class__.__name__ in target_replace_modules:
|
195 |
+
for child_name, child_module in module.named_modules():
|
196 |
+
if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
|
197 |
+
if train_method == 'xattn-strict':
|
198 |
+
if 'out' in child_name:
|
199 |
+
continue
|
200 |
+
if train_method == 'noxattn-hspace':
|
201 |
+
if 'mid_block' not in name:
|
202 |
+
continue
|
203 |
+
if train_method == 'noxattn-hspace-last':
|
204 |
+
if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
|
205 |
+
continue
|
206 |
+
lora_name = prefix + "." + name + "." + child_name
|
207 |
+
lora_name = lora_name.replace(".", "_")
|
208 |
+
# print(f"{lora_name}")
|
209 |
+
lora = self.module(
|
210 |
+
lora_name, child_module, multiplier, rank, self.alpha
|
211 |
+
)
|
212 |
+
# print(name, child_name)
|
213 |
+
# print(child_module.weight.shape)
|
214 |
+
if lora_name not in names:
|
215 |
+
loras.append(lora)
|
216 |
+
names.append(lora_name)
|
217 |
+
# print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
|
218 |
+
return loras
|
219 |
+
|
220 |
+
def prepare_optimizer_params(self):
|
221 |
+
all_params = []
|
222 |
+
|
223 |
+
if self.unet_loras: # ๅฎ่ณชใใใใใชใ
|
224 |
+
params = []
|
225 |
+
[params.extend(lora.parameters()) for lora in self.unet_loras]
|
226 |
+
param_data = {"params": params}
|
227 |
+
all_params.append(param_data)
|
228 |
+
|
229 |
+
return all_params
|
230 |
+
|
231 |
+
def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
|
232 |
+
state_dict = self.state_dict()
|
233 |
+
|
234 |
+
if dtype is not None:
|
235 |
+
for key in list(state_dict.keys()):
|
236 |
+
v = state_dict[key]
|
237 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
238 |
+
state_dict[key] = v
|
239 |
+
|
240 |
+
# for key in list(state_dict.keys()):
|
241 |
+
# if not key.startswith("lora"):
|
242 |
+
# # loraไปฅๅค้คๅค
|
243 |
+
# del state_dict[key]
|
244 |
+
|
245 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
246 |
+
save_file(state_dict, file, metadata)
|
247 |
+
else:
|
248 |
+
torch.save(state_dict, file)
|
249 |
+
def set_lora_slider(self, scale):
|
250 |
+
self.lora_scale = scale
|
251 |
+
|
252 |
+
def __enter__(self):
|
253 |
+
for lora in self.unet_loras:
|
254 |
+
lora.multiplier = 1.0 * self.lora_scale
|
255 |
+
|
256 |
+
def __exit__(self, exc_type, exc_value, tb):
|
257 |
+
for lora in self.unet_loras:
|
258 |
+
lora.multiplier = 0
|