import json
import os.path as osp
import shutil
from itertools import groupby
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline

from lora_diffusion.cli_lora_add import *
from lora_diffusion.lora import *
from lora_diffusion.to_ckpt_v2 import *


def monkeypatch_or_replace_safeloras(models, safeloras):
    loras = parse_safeloras(safeloras)

    for name, (lora, ranks, target) in loras.items():
        model = getattr(models, name, None)

        if not model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)


def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Converts a loaded safetensor file that contains a set of module Loras
    into Parameters and other information.

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }
    """
    loras = {}

    # Unlike the upstream lora_diffusion version, `safeloras` here is a plain
    # dict rather than a safetensors handle: metadata lives under "metadata"
    # and the tensors under "weights".
    metadata = safeloras["metadata"]
    safeloras_ = safeloras["weights"]

    get_name = lambda k: k.split(":")[0]

    keys = list(safeloras_.keys())
    keys.sort(key=get_name)

    for name, module_keys in groupby(keys, get_name):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Skip Textual Inversion embeds
        if info == EMBED_FLAG:
            continue

        # Handle Loras: extract the target replacement modules
        target = json.loads(info)

        # Build the result lists - Python needs us to preallocate lists to insert into them
        module_keys = list(module_keys)
        ranks = [4] * (len(module_keys) // 2)
        weights = [None] * len(module_keys)

        for key in module_keys:
            # Split the model name and index out of the key
            _, idx, direction = key.split(":")
            idx = int(idx)

            # Add the rank
            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

            # Insert the weight into the list
            idx = idx * 2 + (1 if direction == "down" else 0)
            weights[idx] = nn.parameter.Parameter(safeloras_[key])

        loras[name] = (weights, ranks, target)

    return loras


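# Hedged sketch, not part of the original module: parse_safeloras() above indexes
# `safeloras` as a plain dict with "metadata" and "weights" keys. One way to build
# such a dict from a .safetensors file on disk is via safetensors.safe_open; the
# helper name `load_safeloras_dict` and the path handling are illustrative only.
def load_safeloras_dict(path: str) -> dict:
    from safetensors import safe_open

    with safe_open(path, framework="pt", device="cpu") as f:
        return {
            "metadata": f.metadata() or {},
            "weights": {k: f.get_tensor(k) for k in f.keys()},
        }

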
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Converts a loaded safetensor file that contains Textual Inversion embeds into
    a dictionary of embed_token: Tensor
    """
    embeds = {}
    metadata = safeloras["metadata"]
    safeloras_ = safeloras["weights"]

    for key in safeloras_.keys():
        # Only handle Textual Inversion embeds
        meta = metadata.get(key)
        if not meta or meta != EMBED_FLAG:
            continue

        embeds[key] = safeloras_[key]

    return embeds


def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    """
    Patch the pipeline's UNet and text encoder in place with the Loras contained
    in `maybe_unet_path`, which here is expected to be an already-loaded safeloras
    dict rather than a file path, and return the Textual Inversion embeds found in it.

    Note: `r`, `patch_unet`, `patch_text` and the *_target_replace_module
    arguments are kept for signature compatibility but are not used below.
    """
    safeloras = maybe_unet_path
    monkeypatch_or_replace_safeloras(pipe, safeloras)
    tok_dict = parse_safeloras_embeds(safeloras)

    if patch_ti:
        apply_learned_embed_in_clip(
            tok_dict,
            pipe.text_encoder,
            pipe.tokenizer,
            token=token,
            idempotent=idempotent_token,
        )

    return tok_dict


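# Hedged usage sketch, not in the original file: patch_pipe() patches the
# pipeline's UNet and text encoder in place and returns any Textual Inversion
# embeds it finds. The model id below is an illustrative placeholder.
def patch_pipeline_example(safeloras: dict) -> StableDiffusionPipeline:
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    patch_pipe(pipe, safeloras, patch_ti=True)
    return pipe

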
def lora_convert(model_path, as_half):
    """
    Modified version of lora_diffusion.to_ckpt_v2.convert_to_ckpt
    """
    assert model_path is not None, "Must provide a model path!"

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {
        "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
    }

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {
        "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
    }

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if as_half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    return state_dict


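# Hedged usage sketch (not in the original file): the dict returned by
# lora_convert() uses the original CompVis-style key prefixes, so it can be
# written out as a single-file checkpoint with torch.save. The helper name and
# output path below are illustrative only.
def save_as_ckpt(model_path: str, ckpt_path: str = "./merged.ckpt") -> None:
    state_dict = lora_convert(model_path, as_half=True)
    torch.save({"state_dict": state_dict}, ckpt_path)

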
def merge(
    path_1: str,
    path_2: str,
    alpha_1: float = 0.5,
):
    """
    Merge the Lora `path_2` into the base model at `path_1` with weight `alpha_1`.
    As used here, `path_2` is the already-loaded safeloras dict passed through to
    patch_pipe(). Returns the merged checkpoint state dict and the Textual
    Inversion embedding found in the Lora.
    """
    loaded_pipeline = StableDiffusionPipeline.from_pretrained(
        path_1,
    ).to("cpu")

    tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

    # Bake the Lora weights into the base model, then strip the Lora wrappers
    collapse_lora(loaded_pipeline.unet, alpha_1)
    collapse_lora(loaded_pipeline.text_encoder, alpha_1)
    monkeypatch_remove_lora(loaded_pipeline.unet)
    monkeypatch_remove_lora(loaded_pipeline.text_encoder)

    _tmp_output = "./merge.tmp"
    loaded_pipeline.save_pretrained(_tmp_output)
    state_dict = lora_convert(_tmp_output, as_half=True)
    # remove the tmp_output folder
    shutil.rmtree(_tmp_output)

    # Pack the learned embeds into a textual-inversion style .pt layout
    keys = sorted(tok_dict.keys())
    tok_catted = torch.stack([tok_dict[k] for k in keys])
    ret = {
        "string_to_token": {"*": torch.tensor(265)},
        "string_to_param": {"*": tok_catted},
        "name": "",
    }

    return state_dict, ret
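

# Hedged usage sketch, not part of the original module. As used here, merge()
# expects path_2 to be an already-loaded safeloras dict (see load_safeloras_dict
# above) whose Lora also carries a Textual Inversion embed; the model id, file
# names, and alpha below are illustrative placeholders.
if __name__ == "__main__":
    safeloras = load_safeloras_dict("lora_weight.safetensors")
    state_dict, embed = merge("runwayml/stable-diffusion-v1-5", safeloras, alpha_1=0.7)
    torch.save({"state_dict": state_dict}, "merged.ckpt")
    torch.save(embed, "merged_embedding.pt")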