Upload 7 files
- lycoris/__init__.py +8 -0
- lycoris/kohya.py +272 -0
- lycoris/kohya_model_utils.py +1184 -0
- lycoris/kohya_utils.py +48 -0
- lycoris/locon.py +85 -0
- lycoris/loha.py +198 -0
- lycoris/utils.py +380 -0
lycoris/__init__.py
ADDED
@@ -0,0 +1,8 @@
from lycoris import (
    kohya,
    kohya_model_utils,
    kohya_utils,
    locon,
    loha,
    utils,
)
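The package init only re-exports the other six modules in this upload, so everything is reachable under the lycoris namespace. A quick sanity check (a sketch, not part of the upload; it assumes the repository root is on PYTHONPATH and that torch, diffusers, transformers and safetensors are installed, since importing lycoris pulls them in):

# Sketch only: verify the re-exports resolve after installation.
import lycoris

print(lycoris.kohya.create_network)  # network factory defined in lycoris/kohya.py
print(lycoris.loha.LohaModule)       # Hadamard-product module defined in lycoris/loha.py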
lycoris/kohya.py
ADDED
@@ -0,0 +1,272 @@
# network module for kohya
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py

import math
from warnings import warn
import os
from typing import List
import torch

from .kohya_utils import *
from .locon import LoConModule
from .loha import LohaModule


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    if network_dim is None:
        network_dim = 4  # default
    conv_dim = int(kwargs.get('conv_dim', network_dim))
    conv_alpha = float(kwargs.get('conv_alpha', network_alpha))
    dropout = float(kwargs.get('dropout', 0.))
    algo = kwargs.get('algo', 'lora')
    disable_cp = kwargs.get('disable_conv_cp', False)
    network_module = {
        'lora': LoConModule,
        'loha': LohaModule,
    }[algo]

    print(f'Using rank adaptation algo: {algo}')

    if (algo == 'loha'
            and not kwargs.get('no_dim_warn', False)
            and (network_dim > 64 or conv_dim > 64)):
        print('='*20 + 'WARNING' + '='*20)
        warn(
            (
                "You are not supposed to use dim>64 (64*64 = 4096, it already has enough rank)"
                "in Hadamard Product representation!\n"
                "Please consider use lower dim or disable this warning with --network_args no_dim_warn=True\n"
                "If you just want to use high dim loha, please consider use lower lr."
            ),
            stacklevel=2,
        )
        print('='*20 + 'WARNING' + '='*20)

    network = LycorisNetwork(
        text_encoder, unet,
        multiplier=multiplier,
        lora_dim=network_dim, conv_lora_dim=conv_dim,
        alpha=network_alpha, conv_alpha=conv_alpha,
        dropout=dropout,
        use_cp=(not bool(disable_cp)),
        network_module=network_module
    )

    return network


class LycorisNetwork(torch.nn.Module):
    '''
    LoRA + LoCon
    '''
    # Ignore proj_in or proj_out, their channels is only a few.
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'

    def __init__(
        self,
        text_encoder, unet,
        multiplier=1.0,
        lora_dim=4, conv_lora_dim=4,
        alpha=1, conv_alpha=1,
        use_cp=True,
        dropout=0, network_module=LoConModule,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.conv_lora_dim = int(conv_lora_dim)
        if self.conv_lora_dim != self.lora_dim:
            print('Apply different lora dim for conv layer')
            print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')

        self.alpha = alpha
        self.conv_alpha = float(conv_alpha)
        if self.alpha != self.conv_alpha:
            print('Apply different alpha value for conv layer')
            print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')

        if 1 >= dropout >= 0:
            print(f'Use Dropout value: {dropout}')
        self.dropout = dropout

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[network_module]:
            print('Create LyCORIS Module')
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        lora_name = prefix + '.' + name + '.' + child_name
                        lora_name = lora_name.replace('.', '_')
                        if child_module.__class__.__name__ == 'Linear' and lora_dim > 0:
                            lora = network_module(
                                lora_name, child_module, self.multiplier,
                                self.lora_dim, self.alpha, self.dropout, use_cp
                            )
                        elif child_module.__class__.__name__ == 'Conv2d':
                            k_size, *_ = child_module.kernel_size
                            if k_size == 1 and lora_dim > 0:
                                lora = network_module(
                                    lora_name, child_module, self.multiplier,
                                    self.lora_dim, self.alpha, self.dropout, use_cp
                                )
                            elif conv_lora_dim > 0:
                                lora = network_module(
                                    lora_name, child_module, self.multiplier,
                                    self.conv_lora_dim, self.conv_alpha, self.dropout, use_cp
                                )
                            else:
                                continue
                        else:
                            continue
                        loras.append(lora)
            return loras

        self.text_encoder_loras = create_modules(
            LycorisNetwork.LORA_PREFIX_TEXT_ENCODER,
            text_encoder,
            LycorisNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        )
        print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras = create_modules(LycorisNetwork.LORA_PREFIX_UNET, unet, LycorisNetwork.UNET_TARGET_REPLACE_MODULE)
        print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == '.safetensors':
            from safetensors.torch import load_file, safe_open
            self.weights_sd = load_file(file)
        else:
            self.weights_sd = torch.load(file, map_location='cpu')

    def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
        if self.weights_sd:
            weights_has_text_encoder = weights_has_unet = False
            for key in self.weights_sd.keys():
                if key.startswith(LycorisNetwork.LORA_PREFIX_TEXT_ENCODER):
                    weights_has_text_encoder = True
                elif key.startswith(LycorisNetwork.LORA_PREFIX_UNET):
                    weights_has_unet = True

            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"

            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"

        if apply_text_encoder:
            print("enable LyCORIS for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LyCORIS for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        if self.weights_sd:
            # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
            info = self.load_state_dict(self.weights_sd, False)
            print(f"weights are loaded: {info}")

    def enable_gradient_checkpointing(self):
        # not supported
        def make_ckpt(module):
            if isinstance(module, torch.nn.Module):
                module.grad_ckpt = True
        self.apply(make_ckpt)
        pass

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        self.requires_grad_(True)
        all_params = []

        if self.text_encoder_loras:
            param_data = {'params': enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data['lr'] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {'params': enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data['lr'] = unet_lr
            all_params.append(param_data)

        return all_params

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == '.safetensors':
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
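create_network above is the factory that kohya-ss sd-scripts resolves when pointed at this module, and apply_to / prepare_optimizer_params are the hooks the trainer then calls. A minimal sketch of driving it directly from Python (the dims, learning rates and checkpoint path below are illustrative assumptions, not values from the upload; load_models_from_stable_diffusion_checkpoint comes from lycoris/kohya_model_utils.py below):

# Sketch only: wire a LyCORIS network onto freshly loaded SD 1.x modules.
from lycoris.kohya import create_network
from lycoris.kohya_model_utils import load_models_from_stable_diffusion_checkpoint

text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=False,                        # True for SD 2.x checkpoints
    ckpt_path="model.safetensors",   # placeholder path
)

network = create_network(
    multiplier=1.0,
    network_dim=8, network_alpha=4,  # rank/alpha for Linear and 1x1 Conv2d layers
    vae=vae, text_encoder=text_encoder, unet=unet,
    algo="loha",                     # 'lora' -> LoConModule, 'loha' -> LohaModule
    conv_dim=4, conv_alpha=1,        # rank/alpha for the remaining conv layers
    dropout=0.0,
)
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)
param_groups = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4)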
lycoris/kohya_model_utils.py
ADDED
@@ -0,0 +1,1184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
|
3 |
+
'''
|
4 |
+
# v1: split from train_db_fixed.py.
|
5 |
+
# v2: support safetensors
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
11 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
12 |
+
from safetensors.torch import load_file, save_file
|
13 |
+
|
14 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
15 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
16 |
+
BETA_START = 0.00085
|
17 |
+
BETA_END = 0.0120
|
18 |
+
|
19 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
20 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
21 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
22 |
+
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
23 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
24 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
25 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
26 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
27 |
+
UNET_PARAMS_NUM_HEADS = 8
|
28 |
+
|
29 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
30 |
+
VAE_PARAMS_RESOLUTION = 256
|
31 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
32 |
+
VAE_PARAMS_OUT_CH = 3
|
33 |
+
VAE_PARAMS_CH = 128
|
34 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
35 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
36 |
+
|
37 |
+
# V2
|
38 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
39 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
40 |
+
|
41 |
+
# Diffusersの設定を読み込むための参照モデル
|
42 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
43 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
44 |
+
|
45 |
+
|
46 |
+
# region StableDiffusion->Diffusersの変換コード
|
47 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
48 |
+
|
49 |
+
|
50 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
51 |
+
"""
|
52 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
53 |
+
"""
|
54 |
+
if n_shave_prefix_segments >= 0:
|
55 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
56 |
+
else:
|
57 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
58 |
+
|
59 |
+
|
60 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
61 |
+
"""
|
62 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
63 |
+
"""
|
64 |
+
mapping = []
|
65 |
+
for old_item in old_list:
|
66 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
67 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
68 |
+
|
69 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
70 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
71 |
+
|
72 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
73 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
74 |
+
|
75 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
76 |
+
|
77 |
+
mapping.append({"old": old_item, "new": new_item})
|
78 |
+
|
79 |
+
return mapping
|
80 |
+
|
81 |
+
|
82 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
83 |
+
"""
|
84 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
85 |
+
"""
|
86 |
+
mapping = []
|
87 |
+
for old_item in old_list:
|
88 |
+
new_item = old_item
|
89 |
+
|
90 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
91 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
92 |
+
|
93 |
+
mapping.append({"old": old_item, "new": new_item})
|
94 |
+
|
95 |
+
return mapping
|
96 |
+
|
97 |
+
|
98 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
99 |
+
"""
|
100 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
101 |
+
"""
|
102 |
+
mapping = []
|
103 |
+
for old_item in old_list:
|
104 |
+
new_item = old_item
|
105 |
+
|
106 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
107 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
108 |
+
|
109 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
110 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
111 |
+
|
112 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
113 |
+
|
114 |
+
mapping.append({"old": old_item, "new": new_item})
|
115 |
+
|
116 |
+
return mapping
|
117 |
+
|
118 |
+
|
119 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
120 |
+
"""
|
121 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
122 |
+
"""
|
123 |
+
mapping = []
|
124 |
+
for old_item in old_list:
|
125 |
+
new_item = old_item
|
126 |
+
|
127 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
128 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
131 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
134 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
137 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
140 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
141 |
+
|
142 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
143 |
+
|
144 |
+
mapping.append({"old": old_item, "new": new_item})
|
145 |
+
|
146 |
+
return mapping
|
147 |
+
|
148 |
+
|
149 |
+
def assign_to_checkpoint(
|
150 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
151 |
+
):
|
152 |
+
"""
|
153 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
154 |
+
to them. It splits attention layers, and takes into account additional replacements
|
155 |
+
that may arise.
|
156 |
+
|
157 |
+
Assigns the weights to the new checkpoint.
|
158 |
+
"""
|
159 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
160 |
+
|
161 |
+
# Splits the attention layers into three variables.
|
162 |
+
if attention_paths_to_split is not None:
|
163 |
+
for path, path_map in attention_paths_to_split.items():
|
164 |
+
old_tensor = old_checkpoint[path]
|
165 |
+
channels = old_tensor.shape[0] // 3
|
166 |
+
|
167 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
168 |
+
|
169 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
170 |
+
|
171 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
172 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
173 |
+
|
174 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
175 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
176 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
177 |
+
|
178 |
+
for path in paths:
|
179 |
+
new_path = path["new"]
|
180 |
+
|
181 |
+
# These have already been assigned
|
182 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
183 |
+
continue
|
184 |
+
|
185 |
+
# Global renaming happens here
|
186 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
187 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
188 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
189 |
+
|
190 |
+
if additional_replacements is not None:
|
191 |
+
for replacement in additional_replacements:
|
192 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
193 |
+
|
194 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
195 |
+
if "proj_attn.weight" in new_path:
|
196 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
197 |
+
else:
|
198 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
199 |
+
|
200 |
+
|
201 |
+
def conv_attn_to_linear(checkpoint):
|
202 |
+
keys = list(checkpoint.keys())
|
203 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
204 |
+
for key in keys:
|
205 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
208 |
+
elif "proj_attn.weight" in key:
|
209 |
+
if checkpoint[key].ndim > 2:
|
210 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
211 |
+
|
212 |
+
|
213 |
+
def linear_transformer_to_conv(checkpoint):
|
214 |
+
keys = list(checkpoint.keys())
|
215 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
216 |
+
for key in keys:
|
217 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
218 |
+
if checkpoint[key].ndim == 2:
|
219 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
220 |
+
|
221 |
+
|
222 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
223 |
+
"""
|
224 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
225 |
+
"""
|
226 |
+
|
227 |
+
# extract state_dict for UNet
|
228 |
+
unet_state_dict = {}
|
229 |
+
unet_key = "model.diffusion_model."
|
230 |
+
keys = list(checkpoint.keys())
|
231 |
+
for key in keys:
|
232 |
+
if key.startswith(unet_key):
|
233 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
234 |
+
|
235 |
+
new_checkpoint = {}
|
236 |
+
|
237 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
238 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
239 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
240 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
243 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
244 |
+
|
245 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
246 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
247 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
248 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
249 |
+
|
250 |
+
# Retrieves the keys for the input blocks only
|
251 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
252 |
+
input_blocks = {
|
253 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
254 |
+
for layer_id in range(num_input_blocks)
|
255 |
+
}
|
256 |
+
|
257 |
+
# Retrieves the keys for the middle blocks only
|
258 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
259 |
+
middle_blocks = {
|
260 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
261 |
+
for layer_id in range(num_middle_blocks)
|
262 |
+
}
|
263 |
+
|
264 |
+
# Retrieves the keys for the output blocks only
|
265 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
266 |
+
output_blocks = {
|
267 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
268 |
+
for layer_id in range(num_output_blocks)
|
269 |
+
}
|
270 |
+
|
271 |
+
for i in range(1, num_input_blocks):
|
272 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
273 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
274 |
+
|
275 |
+
resnets = [
|
276 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
277 |
+
]
|
278 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
279 |
+
|
280 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.weight"
|
283 |
+
)
|
284 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
285 |
+
f"input_blocks.{i}.0.op.bias"
|
286 |
+
)
|
287 |
+
|
288 |
+
paths = renew_resnet_paths(resnets)
|
289 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
290 |
+
assign_to_checkpoint(
|
291 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
292 |
+
)
|
293 |
+
|
294 |
+
if len(attentions):
|
295 |
+
paths = renew_attention_paths(attentions)
|
296 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
297 |
+
assign_to_checkpoint(
|
298 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
299 |
+
)
|
300 |
+
|
301 |
+
resnet_0 = middle_blocks[0]
|
302 |
+
attentions = middle_blocks[1]
|
303 |
+
resnet_1 = middle_blocks[2]
|
304 |
+
|
305 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
306 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
309 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
310 |
+
|
311 |
+
attentions_paths = renew_attention_paths(attentions)
|
312 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
313 |
+
assign_to_checkpoint(
|
314 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
315 |
+
)
|
316 |
+
|
317 |
+
for i in range(num_output_blocks):
|
318 |
+
block_id = i // (config["layers_per_block"] + 1)
|
319 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
320 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
321 |
+
output_block_list = {}
|
322 |
+
|
323 |
+
for layer in output_block_layers:
|
324 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
325 |
+
if layer_id in output_block_list:
|
326 |
+
output_block_list[layer_id].append(layer_name)
|
327 |
+
else:
|
328 |
+
output_block_list[layer_id] = [layer_name]
|
329 |
+
|
330 |
+
if len(output_block_list) > 1:
|
331 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
332 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
333 |
+
|
334 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
335 |
+
paths = renew_resnet_paths(resnets)
|
336 |
+
|
337 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
338 |
+
assign_to_checkpoint(
|
339 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
340 |
+
)
|
341 |
+
|
342 |
+
# オリジナル:
|
343 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
344 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
345 |
+
|
346 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
347 |
+
for l in output_block_list.values():
|
348 |
+
l.sort()
|
349 |
+
|
350 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
351 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
354 |
+
]
|
355 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
356 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
357 |
+
]
|
358 |
+
|
359 |
+
# Clear attentions as they have been attributed above.
|
360 |
+
if len(attentions) == 2:
|
361 |
+
attentions = []
|
362 |
+
|
363 |
+
if len(attentions):
|
364 |
+
paths = renew_attention_paths(attentions)
|
365 |
+
meta_path = {
|
366 |
+
"old": f"output_blocks.{i}.1",
|
367 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
368 |
+
}
|
369 |
+
assign_to_checkpoint(
|
370 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
374 |
+
for path in resnet_0_paths:
|
375 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
376 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
377 |
+
|
378 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
379 |
+
|
380 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
381 |
+
if v2:
|
382 |
+
linear_transformer_to_conv(new_checkpoint)
|
383 |
+
|
384 |
+
return new_checkpoint
|
385 |
+
|
386 |
+
|
387 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
388 |
+
# extract state dict for VAE
|
389 |
+
vae_state_dict = {}
|
390 |
+
vae_key = "first_stage_model."
|
391 |
+
keys = list(checkpoint.keys())
|
392 |
+
for key in keys:
|
393 |
+
if key.startswith(vae_key):
|
394 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
395 |
+
# if len(vae_state_dict) == 0:
|
396 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
397 |
+
# vae_state_dict = checkpoint
|
398 |
+
|
399 |
+
new_checkpoint = {}
|
400 |
+
|
401 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
402 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
403 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
404 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
405 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
406 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
407 |
+
|
408 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
409 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
410 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
411 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
412 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
413 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
414 |
+
|
415 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
416 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
417 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
418 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
419 |
+
|
420 |
+
# Retrieves the keys for the encoder down blocks only
|
421 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
422 |
+
down_blocks = {
|
423 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
424 |
+
}
|
425 |
+
|
426 |
+
# Retrieves the keys for the decoder up blocks only
|
427 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
428 |
+
up_blocks = {
|
429 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
430 |
+
}
|
431 |
+
|
432 |
+
for i in range(num_down_blocks):
|
433 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
434 |
+
|
435 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
438 |
+
)
|
439 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
440 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
441 |
+
)
|
442 |
+
|
443 |
+
paths = renew_vae_resnet_paths(resnets)
|
444 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
445 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
446 |
+
|
447 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
448 |
+
num_mid_res_blocks = 2
|
449 |
+
for i in range(1, num_mid_res_blocks + 1):
|
450 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
451 |
+
|
452 |
+
paths = renew_vae_resnet_paths(resnets)
|
453 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
454 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
455 |
+
|
456 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
457 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
458 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
459 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
460 |
+
conv_attn_to_linear(new_checkpoint)
|
461 |
+
|
462 |
+
for i in range(num_up_blocks):
|
463 |
+
block_id = num_up_blocks - 1 - i
|
464 |
+
resnets = [
|
465 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
466 |
+
]
|
467 |
+
|
468 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
471 |
+
]
|
472 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
473 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
474 |
+
]
|
475 |
+
|
476 |
+
paths = renew_vae_resnet_paths(resnets)
|
477 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
478 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
479 |
+
|
480 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
481 |
+
num_mid_res_blocks = 2
|
482 |
+
for i in range(1, num_mid_res_blocks + 1):
|
483 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
484 |
+
|
485 |
+
paths = renew_vae_resnet_paths(resnets)
|
486 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
487 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
488 |
+
|
489 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
490 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
491 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
492 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
493 |
+
conv_attn_to_linear(new_checkpoint)
|
494 |
+
return new_checkpoint
|
495 |
+
|
496 |
+
|
497 |
+
def create_unet_diffusers_config(v2):
|
498 |
+
"""
|
499 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
500 |
+
"""
|
501 |
+
# unet_params = original_config.model.params.unet_config.params
|
502 |
+
|
503 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
504 |
+
|
505 |
+
down_block_types = []
|
506 |
+
resolution = 1
|
507 |
+
for i in range(len(block_out_channels)):
|
508 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
509 |
+
down_block_types.append(block_type)
|
510 |
+
if i != len(block_out_channels) - 1:
|
511 |
+
resolution *= 2
|
512 |
+
|
513 |
+
up_block_types = []
|
514 |
+
for i in range(len(block_out_channels)):
|
515 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
516 |
+
up_block_types.append(block_type)
|
517 |
+
resolution //= 2
|
518 |
+
|
519 |
+
config = dict(
|
520 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
521 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
522 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
523 |
+
down_block_types=tuple(down_block_types),
|
524 |
+
up_block_types=tuple(up_block_types),
|
525 |
+
block_out_channels=tuple(block_out_channels),
|
526 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
527 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
528 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
529 |
+
)
|
530 |
+
|
531 |
+
return config
|
532 |
+
|
533 |
+
|
534 |
+
def create_vae_diffusers_config():
|
535 |
+
"""
|
536 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
537 |
+
"""
|
538 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
539 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
540 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
541 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
542 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
543 |
+
|
544 |
+
config = dict(
|
545 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
546 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
547 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
548 |
+
down_block_types=tuple(down_block_types),
|
549 |
+
up_block_types=tuple(up_block_types),
|
550 |
+
block_out_channels=tuple(block_out_channels),
|
551 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
552 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
553 |
+
)
|
554 |
+
return config
|
555 |
+
|
556 |
+
|
557 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
558 |
+
keys = list(checkpoint.keys())
|
559 |
+
text_model_dict = {}
|
560 |
+
for key in keys:
|
561 |
+
if key.startswith("cond_stage_model.transformer"):
|
562 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
563 |
+
return text_model_dict
|
564 |
+
|
565 |
+
|
566 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
567 |
+
# 嫌になるくらい違うぞ!
|
568 |
+
def convert_key(key):
|
569 |
+
if not key.startswith("cond_stage_model"):
|
570 |
+
return None
|
571 |
+
|
572 |
+
# common conversion
|
573 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
574 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
575 |
+
|
576 |
+
if "resblocks" in key:
|
577 |
+
# resblocks conversion
|
578 |
+
key = key.replace(".resblocks.", ".layers.")
|
579 |
+
if ".ln_" in key:
|
580 |
+
key = key.replace(".ln_", ".layer_norm")
|
581 |
+
elif ".mlp." in key:
|
582 |
+
key = key.replace(".c_fc.", ".fc1.")
|
583 |
+
key = key.replace(".c_proj.", ".fc2.")
|
584 |
+
elif '.attn.out_proj' in key:
|
585 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
586 |
+
elif '.attn.in_proj' in key:
|
587 |
+
key = None # 特殊なので後で処理する
|
588 |
+
else:
|
589 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
590 |
+
elif '.positional_embedding' in key:
|
591 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
592 |
+
elif '.text_projection' in key:
|
593 |
+
key = None # 使われない???
|
594 |
+
elif '.logit_scale' in key:
|
595 |
+
key = None # 使われない???
|
596 |
+
elif '.token_embedding' in key:
|
597 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
598 |
+
elif '.ln_final' in key:
|
599 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
600 |
+
return key
|
601 |
+
|
602 |
+
keys = list(checkpoint.keys())
|
603 |
+
new_sd = {}
|
604 |
+
for key in keys:
|
605 |
+
# remove resblocks 23
|
606 |
+
if '.resblocks.23.' in key:
|
607 |
+
continue
|
608 |
+
new_key = convert_key(key)
|
609 |
+
if new_key is None:
|
610 |
+
continue
|
611 |
+
new_sd[new_key] = checkpoint[key]
|
612 |
+
|
613 |
+
# attnの変換
|
614 |
+
for key in keys:
|
615 |
+
if '.resblocks.23.' in key:
|
616 |
+
continue
|
617 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
618 |
+
# 三つに分割
|
619 |
+
values = torch.chunk(checkpoint[key], 3)
|
620 |
+
|
621 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
622 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
623 |
+
key_pfx = key_pfx.replace("_weight", "")
|
624 |
+
key_pfx = key_pfx.replace("_bias", "")
|
625 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
626 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
627 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
628 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
629 |
+
|
630 |
+
# rename or add position_ids
|
631 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
632 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
633 |
+
# waifu diffusion v1.4
|
634 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
635 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
636 |
+
else:
|
637 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
638 |
+
|
639 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
640 |
+
return new_sd
|
641 |
+
|
642 |
+
# endregion
|
643 |
+
|
644 |
+
|
645 |
+
# region Diffusers->StableDiffusion の変換コード
|
646 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
647 |
+
|
648 |
+
def conv_transformer_to_linear(checkpoint):
|
649 |
+
keys = list(checkpoint.keys())
|
650 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
651 |
+
for key in keys:
|
652 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
653 |
+
if checkpoint[key].ndim > 2:
|
654 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
655 |
+
|
656 |
+
|
657 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
658 |
+
unet_conversion_map = [
|
659 |
+
# (stable-diffusion, HF Diffusers)
|
660 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
661 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
662 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
663 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
664 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
665 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
666 |
+
("out.0.weight", "conv_norm_out.weight"),
|
667 |
+
("out.0.bias", "conv_norm_out.bias"),
|
668 |
+
("out.2.weight", "conv_out.weight"),
|
669 |
+
("out.2.bias", "conv_out.bias"),
|
670 |
+
]
|
671 |
+
|
672 |
+
unet_conversion_map_resnet = [
|
673 |
+
# (stable-diffusion, HF Diffusers)
|
674 |
+
("in_layers.0", "norm1"),
|
675 |
+
("in_layers.2", "conv1"),
|
676 |
+
("out_layers.0", "norm2"),
|
677 |
+
("out_layers.3", "conv2"),
|
678 |
+
("emb_layers.1", "time_emb_proj"),
|
679 |
+
("skip_connection", "conv_shortcut"),
|
680 |
+
]
|
681 |
+
|
682 |
+
unet_conversion_map_layer = []
|
683 |
+
for i in range(4):
|
684 |
+
# loop over downblocks/upblocks
|
685 |
+
|
686 |
+
for j in range(2):
|
687 |
+
# loop over resnets/attentions for downblocks
|
688 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
689 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
690 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
691 |
+
|
692 |
+
if i < 3:
|
693 |
+
# no attention layers in down_blocks.3
|
694 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
695 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
696 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
697 |
+
|
698 |
+
for j in range(3):
|
699 |
+
# loop over resnets/attentions for upblocks
|
700 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
701 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
702 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
703 |
+
|
704 |
+
if i > 0:
|
705 |
+
# no attention layers in up_blocks.0
|
706 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
707 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
708 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
709 |
+
|
710 |
+
if i < 3:
|
711 |
+
# no downsample in down_blocks.3
|
712 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
713 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
714 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
715 |
+
|
716 |
+
# no upsample in up_blocks.3
|
717 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
718 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
719 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
720 |
+
|
721 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
722 |
+
sd_mid_atn_prefix = "middle_block.1."
|
723 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
724 |
+
|
725 |
+
for j in range(2):
|
726 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
727 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
728 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
729 |
+
|
730 |
+
# buyer beware: this is a *brittle* function,
|
731 |
+
# and correct output requires that all of these pieces interact in
|
732 |
+
# the exact order in which I have arranged them.
|
733 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
734 |
+
for sd_name, hf_name in unet_conversion_map:
|
735 |
+
mapping[hf_name] = sd_name
|
736 |
+
for k, v in mapping.items():
|
737 |
+
if "resnets" in k:
|
738 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
739 |
+
v = v.replace(hf_part, sd_part)
|
740 |
+
mapping[k] = v
|
741 |
+
for k, v in mapping.items():
|
742 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
743 |
+
v = v.replace(hf_part, sd_part)
|
744 |
+
mapping[k] = v
|
745 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
746 |
+
|
747 |
+
if v2:
|
748 |
+
conv_transformer_to_linear(new_state_dict)
|
749 |
+
|
750 |
+
return new_state_dict
|
751 |
+
|
752 |
+
|
753 |
+
# ================#
|
754 |
+
# VAE Conversion #
|
755 |
+
# ================#
|
756 |
+
|
757 |
+
def reshape_weight_for_sd(w):
|
758 |
+
# convert HF linear weights to SD conv2d weights
|
759 |
+
return w.reshape(*w.shape, 1, 1)
|
760 |
+
|
761 |
+
|
762 |
+
def convert_vae_state_dict(vae_state_dict):
|
763 |
+
vae_conversion_map = [
|
764 |
+
# (stable-diffusion, HF Diffusers)
|
765 |
+
("nin_shortcut", "conv_shortcut"),
|
766 |
+
("norm_out", "conv_norm_out"),
|
767 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
768 |
+
]
|
769 |
+
|
770 |
+
for i in range(4):
|
771 |
+
# down_blocks have two resnets
|
772 |
+
for j in range(2):
|
773 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
774 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
775 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
776 |
+
|
777 |
+
if i < 3:
|
778 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
779 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
780 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
781 |
+
|
782 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
783 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
784 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
785 |
+
|
786 |
+
# up_blocks have three resnets
|
787 |
+
# also, up blocks in hf are numbered in reverse from sd
|
788 |
+
for j in range(3):
|
789 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
790 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
791 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
792 |
+
|
793 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
794 |
+
for i in range(2):
|
795 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
796 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
797 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
798 |
+
|
799 |
+
vae_conversion_map_attn = [
|
800 |
+
# (stable-diffusion, HF Diffusers)
|
801 |
+
("norm.", "group_norm."),
|
802 |
+
("q.", "query."),
        ("k.", "key."),
        ("v.", "value."),
        ("proj_out.", "proj_attn."),
    ]

    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                # print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)

    return new_state_dict


# endregion

# region custom model loading/saving helpers


def is_safetensors(path):
    return os.path.splitext(path)[1].lower() == '.safetensors'


def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    # handle models whose text encoder keys are stored in a different layout (without 'text_model')
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
        ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
        ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
    ]

    if is_safetensors(ckpt_path):
        checkpoint = None
        state_dict = load_file(ckpt_path, "cpu")
    else:
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
            checkpoint = None

    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from):]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint, state_dict


# TODO the dtype handling looks unreliable and should be checked; building the text_encoder in the requested dtype is unverified
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    if dtype is not None:
        for k, v in state_dict.items():
            if type(v) is torch.Tensor:
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(v2)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)
    print("loading u-net:", info)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("loading vae:", info)

    # convert text_model
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    print("loading text encoder:", info)

    return text_model, vae, unet
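
For reference, a minimal usage sketch of the loader above; the checkpoint path is a placeholder, not a file shipped with this repo, and the call otherwise follows the signature defined in this file.

# Sketch only: "some_model.safetensors" is a placeholder path.
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path="some_model.safetensors", dtype=None
)
# the returned objects are ordinary transformers/diffusers modules
print(type(text_encoder).__name__, type(vae).__name__, type(unet).__name__)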


def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
    def convert_key(key):
        # drop position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case, handled separately below
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # convert the attention weights
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # concatenate q/k/v into a single in_proj weight
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    # optionally fabricate the last layer and related weights
    if make_dummy_weights:
        print("make dummy weights for resblock.23, text_projection and logit scale.")
        keys = list(new_sd.keys())
        for key in keys:
            if key.startswith("transformer.resblocks.22."):
                new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone()  # must be cloned, otherwise saving with safetensors fails

        # create weights that are not included in the Diffusers model
        new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
        new_sd['logit_scale'] = torch.tensor(1)

    return new_sd


def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
    if ckpt_path is not None:
        # reuse epoch/step from the source checkpoint, and reload it (including the VAE) e.g. when the VAE is not in memory
        checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
        if checkpoint is None:  # safetensors file, or a ckpt that is only a state_dict
            checkpoint = {}
            strict = False
        else:
            strict = True
        if "state_dict" in state_dict:
            del state_dict["state_dict"]
    else:
        # build a new checkpoint from scratch
        assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
        checkpoint = {}
        state_dict = {}
        strict = False

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
    update_sd("model.diffusion_model.", unet_state_dict)

    # Convert the text encoder model
    if v2:
        make_dummy = ckpt_path is None  # when there is no source checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
        update_sd("cond_stage_model.model.", text_enc_dict)
    else:
        text_enc_dict = text_encoder.state_dict()
        update_sd("cond_stage_model.transformer.", text_enc_dict)

    # Convert the VAE
    if vae is not None:
        vae_dict = convert_vae_state_dict(vae.state_dict())
        update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {'state_dict': state_dict}

    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']

    new_ckpt['epoch'] = epochs
    new_ckpt['global_step'] = steps

    if is_safetensors(output_file):
        # TODO should values other than Tensors be removed from the dict?
        save_file(state_dict, output_file)
    else:
        torch.save(new_ckpt, output_file)

    return key_count
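
And the matching save path, again as a hedged sketch with placeholder file names: epochs/steps are added on top of whatever the source checkpoint already records.

# Sketch only: file names are placeholders.
key_count = save_stable_diffusion_checkpoint(
    v2=False,
    output_file="finetuned.safetensors",
    text_encoder=text_encoder, unet=unet,
    ckpt_path="some_model.safetensors",  # source checkpoint; pass None and supply vae= to build from scratch
    epochs=1, steps=1000,
    save_dtype=torch.float16,
)
print("saved", key_count, "keys")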


def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
    if pretrained_model_name_or_path is None:
        # load default settings for v1/v2
        if v2:
            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
        else:
            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1

    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=None,
    )
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)


VAE_PREFIX = "first_stage_model."


def load_vae(vae_id, dtype):
    print(f"load VAE: {vae_id}")
    if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
        # Diffusers local/remote
        try:
            vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
        except EnvironmentError as e:
            print(f"exception occurs in loading vae: {e}")
            print("retry with subfolder='vae'")
            vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
        return vae

    # local
    vae_config = create_vae_diffusers_config()

    if vae_id.endswith(".bin"):
        # SD 1.5 VAE on Huggingface
        converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
    else:
        # StableDiffusion
        vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
                     else torch.load(vae_id, map_location="cpu"))
        vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model

        # vae only or full model
        full_model = False
        for vae_key in vae_sd:
            if vae_key.startswith(VAE_PREFIX):
                full_model = True
                break
        if not full_model:
            sd = {}
            for key, value in vae_sd.items():
                sd[VAE_PREFIX + key] = value
            vae_sd = sd
            del sd

        # Convert the VAE model.
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    return vae
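
load_vae accepts either a Diffusers model id/directory or a single VAE/checkpoint file; a short sketch, with both identifiers as placeholders rather than files provided by this repo:

# Sketch only: identifiers are placeholders.
vae = load_vae("stabilityai/sd-vae-ft-mse", dtype=torch.float32)  # Diffusers repo id or local dir
vae = load_vae("some_vae.safetensors", dtype=torch.float32)       # single .safetensors/.ckpt/.bin file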

# endregion


def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
    max_width, max_height = max_reso
    max_area = (max_width // divisible) * (max_height // divisible)

    resos = set()

    size = int(math.sqrt(max_area)) * divisible
    resos.add((size, size))

    size = min_size
    while size <= max_size:
        width = size
        height = min(max_size, (max_area // (width // divisible)) * divisible)
        resos.add((width, height))
        resos.add((height, width))

        # # make additional resos
        # if width >= height and width - divisible >= min_size:
        #     resos.add((width - divisible, height))
        #     resos.add((height, width - divisible))
        # if height >= width and height - divisible >= min_size:
        #     resos.add((width, height - divisible))
        #     resos.add((height - divisible, width))

        size += divisible

    resos = list(resos)
    resos.sort()

    aspect_ratios = [w / h for w, h in resos]
    return resos, aspect_ratios


if __name__ == '__main__':
    resos, aspect_ratios = make_bucket_resolutions((512, 768))
    print(len(resos))
    print(resos)
    print(aspect_ratios)

    ars = set()
    for ar in aspect_ratios:
        if ar in ars:
            print("error! duplicate ar:", ar)
        ars.add(ar)
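
make_bucket_resolutions only enumerates candidate buckets; assigning an image to a bucket is left to the caller. The helper below is hypothetical (not part of this file) and simply picks the bucket with the closest aspect ratio, which is the usual way the returned lists are consumed.

def pick_bucket(width, height, resos, aspect_ratios):
    # hypothetical helper: choose the bucket whose aspect ratio is closest to the image's
    ar = width / height
    idx = min(range(len(resos)), key=lambda i: abs(aspect_ratios[i] - ar))
    return resos[idx]

resos, aspect_ratios = make_bucket_resolutions((512, 768))
print(pick_bucket(400, 640, resos, aspect_ratios))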
lycoris/kohya_utils.py
ADDED
@@ -0,0 +1,48 @@

# part of https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py

import hashlib
import safetensors.torch
from io import BytesIO


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    m = hashlib.sha256()

    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    bytes = safetensors.torch.save(tensors, metadata)
    b = BytesIO(bytes)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
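
A sketch (not part of this file) of wiring these hashes into a safetensors save; the metadata key names ("ss_*", "sshs_*"), the placeholder tensor, and the output path are assumptions, while safetensors.torch.save_file is the standard safetensors API.

# Sketch only: keys, tensors and paths below are placeholders.
import torch
from safetensors.torch import save_file

tensors = {"lora_unet_dummy.lora_up.weight": torch.zeros(4, 4)}
metadata = {"ss_network_dim": "16", "ss_network_alpha": "8"}
model_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(tensors, "lycoris_weights.safetensors", metadata)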
lycoris/locon.py
ADDED
@@ -0,0 +1,85 @@

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoConModule(nn.Module):
    """
    modified from kohya-ss/sd-scripts/networks/lora:LoRAModule
    """

    def __init__(
        self,
        lora_name, org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4, alpha=1,
        dropout=0.,
        use_cp=True,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False

        if org_module.__class__.__name__ == 'Conv2d':
            # For general LoCon
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            out_dim = org_module.out_channels
            if use_cp and k_size != (1, 1):
                self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
                self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
                self.cp = True
            else:
                self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
            self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
        self.shape = org_module.weight.shape

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)
        if self.cp:
            torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))

        self.multiplier = multiplier
        self.org_module = [org_module]

    def apply_to(self):
        self.org_forward = self.org_module[0].forward
        self.org_module[0].forward = self.forward

    def make_weight(self):
        wa = self.lora_up.weight
        wb = self.lora_down.weight
        return (wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1)).view(self.shape)

    def forward(self, x):
        if self.cp:
            return self.org_forward(x) + self.dropout(
                self.lora_up(self.lora_mid(self.lora_down(x))) * self.multiplier * self.scale
            )
        else:
            return self.org_forward(x) + self.dropout(
                self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
            )
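
A small sanity check for the module above, assuming only torch/nn as imported in this file: lora_up is zero-initialized, so wrapping a layer leaves its output unchanged until training moves the low-rank weights.

conv = nn.Conv2d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 32, 32)
ref = conv(x)                                    # output before wrapping
lora = LoConModule("test_conv", conv, multiplier=1.0, lora_dim=8, alpha=4)
lora.apply_to()                                  # conv.forward now routes through the module
print(torch.allclose(conv(x), ref))              # True at init because lora_up starts at zero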
lycoris/loha.py
ADDED
@@ -0,0 +1,198 @@

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, orig_weight, w1a, w1b, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
        return orig_weight.reshape(diff_weight.shape) + diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale
        temp = grad_out * (w2a @ w2b)
        grad_w1a = temp @ w1b.T
        grad_w1b = w1a.T @ temp

        temp = grad_out * (w1a @ w1b)
        grad_w2a = temp @ w2b.T
        grad_w2b = w2a.T @ temp

        del temp
        return grad_out, grad_w1a, grad_w1b, grad_w2a, grad_w2b, None


class HadaWeightCP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale=torch.tensor(1)):
        ctx.save_for_backward(t1, w1a, w1b, t2, w2a, w2b, scale)

        rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', t1, w1b, w1a)
        rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', t2, w2b, w2a)

        return orig_weight + rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1a, w1b, t2, w2a, w2b, scale) = ctx.saved_tensors

        grad_out = grad_out * scale

        temp = torch.einsum('i j k l, j r -> i r k l', t2, w2b)
        rebuild = torch.einsum('i j k l, i r -> r j k l', temp, w2a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1a = torch.einsum('r j k l, i j k l -> r i', temp, grad_w)
        grad_temp = torch.einsum('i j k l, i r -> r j k l', grad_w, w1a.T)
        del grad_w, temp

        grad_w1b = torch.einsum('i r k l, i j k l -> r j', t1, grad_temp)
        grad_t1 = torch.einsum('i j k l, j r -> i r k l', grad_temp, w1b.T)
        del grad_temp

        temp = torch.einsum('i j k l, j r -> i r k l', t1, w1b)
        rebuild = torch.einsum('i j k l, i r -> r j k l', temp, w1a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2a = torch.einsum('r j k l, i j k l -> r i', temp, grad_w)
        grad_temp = torch.einsum('i j k l, i r -> r j k l', grad_w, w2a.T)
        del grad_w, temp

        grad_w2b = torch.einsum('i r k l, i j k l -> r j', t2, grad_temp)
        grad_t2 = torch.einsum('i j k l, j r -> i r k l', grad_temp, w2b.T)
        del grad_temp
        return grad_out, grad_t1, grad_w1a, grad_w1b, grad_t2, grad_w2a, grad_w2b, None


def make_weight(orig_weight, w1a, w1b, w2a, w2b, scale):
    return HadaWeight.apply(orig_weight, w1a, w1b, w2a, w2b, scale)


def make_weight_cp(orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale):
    return HadaWeightCP.apply(orig_weight, t1, w1a, w1b, t2, w2a, w2b, scale)


class LohaModule(nn.Module):
    """
    Hadamard product implementation for low-rank adaptation
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0, lora_dim=4, alpha=1, dropout=0.,
        use_cp=True,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False

        self.shape = org_module.weight.shape
        if org_module.__class__.__name__ == 'Conv2d':
            in_dim = org_module.in_channels
            k_size = org_module.kernel_size
            out_dim = org_module.out_channels
            self.cp = use_cp and k_size != (1, 1)
            if self.cp:
                shape = (out_dim, in_dim, *k_size)
            else:
                shape = (out_dim, in_dim * k_size[0] * k_size[1])
            self.op = F.conv2d
            self.extra_args = {
                "stride": org_module.stride,
                "padding": org_module.padding,
                "dilation": org_module.dilation,
                "groups": org_module.groups
            }
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            shape = (out_dim, in_dim)
            self.op = F.linear
            self.extra_args = {}

        if self.cp:
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
            self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, shape[0]))  # out_dim, 1-mode
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))  # in_dim , 2-mode

            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3]))
            self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0]))  # out_dim, 1-mode
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))  # in_dim , 2-mode
        else:
            self.hada_w1_a = nn.Parameter(torch.empty(shape[0], lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1]))

            self.hada_w2_a = nn.Parameter(torch.empty(shape[0], lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1]))

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer('alpha', torch.tensor(alpha))  # treated as a constant

        # Needs more experiments on the init method
        if self.cp:
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
        torch.nn.init.normal_(self.hada_w1_b, std=1)
        torch.nn.init.normal_(self.hada_w2_b, std=0.01)
        torch.nn.init.normal_(self.hada_w1_a, std=1)
        torch.nn.init.constant_(self.hada_w2_a, 0)

        self.multiplier = multiplier
        self.org_module = [org_module]  # remove in applying
        self.grad_ckpt = False

    def apply_to(self):
        self.org_module[0].forward = self.forward

    def get_weight(self):
        d_weight = self.hada_w1_a @ self.hada_w1_b
        d_weight *= self.hada_w2_a @ self.hada_w2_b
        return (d_weight).reshape(self.shape)

    @torch.enable_grad()
    def forward(self, x):
        # print(torch.mean(torch.abs(self.orig_w1a.to(x.device) - self.hada_w1_a)), end='\r')
        if self.cp:
            weight = make_weight_cp(
                self.org_module[0].weight.data,
                self.hada_t1, self.hada_w1_a, self.hada_w1_b,
                self.hada_t2, self.hada_w2_a, self.hada_w2_b,  # fixed: the original passed hada_t1 here a second time
                scale=torch.tensor(self.scale * self.multiplier),
            )
        else:
            weight = make_weight(
                self.org_module[0].weight.data,
                self.hada_w1_a, self.hada_w1_b,
                self.hada_w2_a, self.hada_w2_b,
                scale=torch.tensor(self.scale * self.multiplier),
            )

        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
        return self.op(
            x,
            weight.view(self.shape),
            bias,
            **self.extra_args
        )
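
The same kind of sanity check for LoHa, a sketch assuming only this file: hada_w2_a is initialized to zero, so the Hadamard-product delta (w1a @ w1b) * (w2a @ w2b) starts at zero and the wrapped layer initially reproduces its original output.

linear = nn.Linear(128, 64)
x = torch.randn(2, 128)
ref = linear(x)                                   # output before wrapping
loha = LohaModule("test_linear", linear, multiplier=1.0, lora_dim=8, alpha=4)
loha.apply_to()
print(torch.allclose(linear(x), ref))             # True at init because hada_w2_a starts at zero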
lycoris/utils.py
ADDED
@@ -0,0 +1,380 @@

from typing import *

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.linalg as linalg

from tqdm import tqdm


def make_sparse(t: torch.Tensor, sparsity=0.95):
    abs_t = torch.abs(t)
    np_array = abs_t.detach().cpu().numpy()
    quan = float(np.quantile(np_array, sparsity))
    sparse_t = t.masked_fill(abs_t < quan, 0)
    return sparse_t


def extract_conv(
    weight: Union[torch.Tensor, nn.Parameter],
    mode = 'fixed',
    mode_param = 0,
    device = 'cpu',
) -> Tuple[nn.Parameter, nn.Parameter]:
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))

    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, diff


def merge_conv(
    weight_a: Union[torch.Tensor, nn.Parameter],
    weight_b: Union[torch.Tensor, nn.Parameter],
    device = 'cpu'
):
    rank, in_ch, kernel_size, k_ = weight_a.shape
    out_ch, rank_, _, _ = weight_b.shape
    assert rank == rank_ and kernel_size == k_

    wa = weight_a.to(device)
    wb = weight_b.to(device)

    if device == 'cpu':
        wa = wa.float()
        wb = wb.float()

    merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
    weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
    del wb, wa
    return weight


def extract_linear(
    weight: Union[torch.Tensor, nn.Parameter],
    mode = 'fixed',
    mode_param = 0,
    device = 'cpu',
) -> Tuple[nn.Parameter, nn.Parameter]:
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = linalg.svd(weight)

    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, diff


def merge_linear(
    weight_a: Union[torch.Tensor, nn.Parameter],
    weight_b: Union[torch.Tensor, nn.Parameter],
    device = 'cpu'
):
    rank, in_ch = weight_a.shape
    out_ch, rank_ = weight_b.shape
    assert rank == rank_

    wa = weight_a.to(device)
    wb = weight_b.to(device)

    if device == 'cpu':
        wa = wa.float()
        wb = wb.float()

    weight = wb @ wa
    del wb, wa
    return weight


def extract_diff(
    base_model,
    db_model,
    mode = 'fixed',
    linear_mode_param = 0,
    conv_mode_param = 0,
    extract_device = 'cpu',
    use_bias = False,
    sparsity = 0.98,
    small_conv = True
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules
    ):
        loras = {}
        temp = {}

        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    temp[name][child_name] = child_module.weight

        for name, module in tqdm(list(target_module.named_modules())):
            if name in temp:
                weights = temp[name]
                for child_name, child_module in module.named_modules():
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    layer = child_module.__class__.__name__
                    if layer == 'Linear':
                        extract_a, extract_b, diff = extract_linear(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param,
                            device = extract_device,
                        )
                    elif layer == 'Conv2d':
                        is_linear = (child_module.weight.shape[2] == 1
                                     and child_module.weight.shape[3] == 1)
                        extract_a, extract_b, diff = extract_conv(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param if is_linear else conv_mode_param,
                            device = extract_device,
                        )
                        if small_conv and not is_linear:
                            dim = extract_a.size(0)
                            extract_c, extract_a, _ = extract_conv(
                                extract_a.transpose(0, 1),
                                'fixed', dim,
                                extract_device
                            )
                            extract_a = extract_a.transpose(0, 1)
                            extract_c = extract_c.transpose(0, 1)
                            loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                            diff = child_module.weight - torch.einsum(
                                'i j k l, j r, p i -> p r k l',
                                extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                            ).detach().cpu().contiguous()
                            del extract_c
                    else:
                        continue
                    loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()

                    if use_bias:
                        diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                        sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                        indices = sparse_diff.indices().to(torch.int16)
                        values = sparse_diff.values().half()
                        loras[f'{lora_name}.bias_indices'] = indices
                        loras[f'{lora_name}.bias_values'] = values
                        loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                    del extract_a, extract_b, diff
        return loras

    text_encoder_loras = make_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0], db_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )

    unet_loras = make_state_dict(
        LORA_PREFIX_UNET,
        base_model[2], db_model[2],
        UNET_TARGET_REPLACE_MODULE
    )
    print(len(text_encoder_loras), len(unet_loras))
    return text_encoder_loras | unet_loras


def merge_locon(
    base_model,
    locon_state_dict: Dict[str, torch.TensorType],
    scale: float = 1.0,
    device = 'cpu'
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    def merge(
        prefix,
        root_module: torch.nn.Module,
        target_replace_modules
    ):
        temp = {}

        for name, module in tqdm(list(root_module.named_modules())):
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    layer = child_module.__class__.__name__
                    if layer not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
                    up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
                    alpha = locon_state_dict[f'{lora_name}.alpha'].float()
                    rank = down.shape[0]

                    if layer == 'Conv2d':
                        delta = merge_conv(down, up, device)
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
                        del delta
                    elif layer == 'Linear':
                        delta = merge_linear(down, up, device)
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device)/rank * scale * delta).cpu()
                        del delta

    merge(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    merge(
        LORA_PREFIX_UNET,
        base_model[2],
        UNET_TARGET_REPLACE_MODULE
    )


def merge_loha(
    base_model,
    loha_state_dict: Dict[str, torch.TensorType],
    scale: float = 1.0,
    device = 'cpu'
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    def merge(
        prefix,
        root_module: torch.nn.Module,
        target_replace_modules
    ):
        temp = {}

        for name, module in tqdm(list(root_module.named_modules())):
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    layer = child_module.__class__.__name__
                    if layer not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
                    w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
                    w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
                    w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
                    alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
                    dim = w1b.shape[0]

                    delta = (w1a @ w1b) * (w2a @ w2b)
                    delta = delta.reshape(child_module.weight.shape)

                    if layer == 'Conv2d':
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
                    elif layer == 'Linear':
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device)/dim * scale * delta).cpu()
                    del delta

    merge(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    merge(
        LORA_PREFIX_UNET,
        base_model[2],
        UNET_TARGET_REPLACE_MODULE
    )
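
A quick sanity check of the extraction/merge pair above, sketched on a random matrix rather than real model weights: merge_linear(A, B) rebuilds the rank-r approximation B @ A, and the diff returned by extract_linear is exactly the reconstruction error.

W = torch.randn(64, 128)
A, B, diff = extract_linear(W, mode='fixed', mode_param=32)  # rank-32 SVD approximation
W_approx = merge_linear(A, B)                                # equals B @ A
print(torch.allclose(W - W_approx, diff, atol=1e-5))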