Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from logging import warnings | |
import torch | |
from typing import Union | |
from types import SimpleNamespace | |
from models.unet.unet_3d_condition import UNet3DConditionModel | |
from transformers import CLIPTextModel | |
from .convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20 | |
from .lora import ( | |
extract_lora_ups_down, | |
inject_trainable_lora_extended, | |
save_lora_weight, | |
train_patch_pipe, | |
monkeypatch_or_replace_lora, | |
monkeypatch_or_replace_lora_extended | |
) | |
# Filename stems used when saving / locating LoRA weight files (index 0: UNet, 1: text encoder).
FILE_BASENAMES = ['unet', 'text_encoder']
# File extensions recognised as LoRA weight files.
LORA_FILE_TYPES = ['.pt', '.safetensors']

# Argument names each LoRA implementation accepts, used to filter `lora_args`.
CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r']
STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias']

# Supported LoRA implementations (name -> identifier).
lora_versions = {
    "stable_lora": "stable_lora",
    "cloneofsimo": "cloneofsimo",
}

# Kinds of LoRA callables a handler can resolve (name -> identifier).
lora_func_types = {
    "loader": "loader",
    "injector": "injector",
}

# Default value for every known LoRA argument; filtered per implementation.
lora_args = {
    "model": None,
    "loras": None,
    "target_replace_module": [],
    "target_module": [],
    "r": 4,
    "search_class": [torch.nn.Linear],
    "dropout": 0,
    "lora_bias": 'none',
}

# Attribute-style access, e.g. `LoraVersions.cloneofsimo`.
LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)

# Ordered lists of the valid identifiers (dict insertion order is preserved).
LORA_VERSIONS = list(lora_versions.values())
LORA_FUNC_TYPES = list(lora_func_types.values())
def filter_dict(_dict, keys=None):
    """Return the subset of ``_dict`` whose keys appear in ``keys``.

    Args:
        _dict: Mapping to filter.
        keys: Non-empty iterable of key names; each must be a key of the
            module-level ``lora_args``.

    Returns:
        A new dict containing only the entries of ``_dict`` whose key is in ``keys``.

    Raises:
        ValueError: If ``keys`` is empty/None or names an unknown LoRA argument.
    """
    # The previous `assert "message"` form never fired: a non-empty string is
    # always truthy. Raise explicitly so the validation actually happens
    # (and is not stripped under `python -O`).
    if not keys:
        raise ValueError("Keys cannot be empty for filtering return dict.")
    for k in keys:
        if k not in lora_args:
            raise ValueError(f"{k} does not exist in available LoRA arguments")
    return {k: v for k, v in _dict.items() if k in keys}
class LoraHandler(object):
    """Inject, load and save LoRA weights for a UNet and/or a text encoder.

    Only the ``cloneofsimo`` implementation has loader/injector callables wired
    up in this module; for ``stable_lora`` they resolve to ``None``.
    """

    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = 'none',
        unet_replace_modules: list = None,
        text_encoder_replace_modules: list = None
    ):
        """Store the LoRA configuration and resolve the loader/injector callables.

        Args:
            version: LoRA implementation identifier, one of ``LORA_VERSIONS``.
            use_unet_lora: Whether UNet LoRA weights are saved by ``save_lora_weights``.
            use_text_lora: Whether text-encoder LoRA weights are saved.
            save_for_webui: Only used to emit a warning for the cloneofsimo version.
            only_for_webui: Only used to emit a warning for the cloneofsimo version.
            lora_bias: Forwarded as ``bias`` to ``do_lora_injection``.
            unet_replace_modules: Module names targeted in the UNet; ``None``
                skips saving UNet weights.
            text_encoder_replace_modules: Module names targeted in the text
                encoder; ``None`` skips saving text-encoder weights.

        Raises:
            ValueError: If ``version`` is not one of ``LORA_VERSIONS``.
        """
        # `version` must be set before the lookups below, which read it.
        self.version = version
        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        self.unet_replace_modules = unet_replace_modules
        self.text_encoder_replace_modules = text_encoder_replace_modules
        self.use_lora = any([use_text_lora, use_unet_lora])

    def is_cloneofsimo_lora(self):
        """Return True when this handler uses the cloneofsimo LoRA implementation."""
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
        """Return the LoRA callable of kind ``func_type`` for this handler's version.

        Returns:
            The loader/injector function for cloneofsimo, or ``None`` when the
            version has no implementation here (``stable_lora``) or
            ``func_type`` is unrecognised.

        Raises:
            ValueError: If ``self.version`` is not one of ``LORA_VERSIONS``.
        """
        if self.is_cloneofsimo_lora():
            if func_type == LoraFuncTypes.loader:
                return monkeypatch_or_replace_lora_extended
            if func_type == LoraFuncTypes.injector:
                return inject_trainable_lora_extended
        # The previous `assert "..."` was a no-op (non-empty strings are truthy),
        # so an unknown version silently produced a handler with `None` callables.
        if self.version not in LORA_VERSIONS:
            raise ValueError(f"LoRA Version does not exist: {self.version!r}")
        return None

    def check_lora_ext(self, lora_file: str):
        """Return True if ``lora_file`` ends with one of ``LORA_FILE_TYPES``."""
        return lora_file.endswith(tuple(LORA_FILE_TYPES))

    def get_lora_file_path(
        self,
        lora_path: str,
        model: Union[UNet3DConditionModel, CLIPTextModel]
    ):
        """Find the LoRA weight file for ``model`` inside directory ``lora_path``.

        The first directory entry with a LoRA extension whose name contains the
        model's basename (``'unet'`` for ``UNet3DConditionModel``, otherwise
        ``'text_encoder'``) is returned.

        Returns:
            The joined file path, or ``None`` if ``lora_path`` is not a directory
            or no matching file exists.
        """
        # isdir (not exists): os.listdir would raise NotADirectoryError if
        # lora_path pointed at a regular file.
        if not os.path.isdir(lora_path):
            return None

        base_name = FILE_BASENAMES[0 if isinstance(model, UNet3DConditionModel) else 1]

        # NOTE(review): os.listdir order is arbitrary, so with several matching
        # files (e.g. one per saved step) the choice is filesystem-dependent.
        for lora_filename in os.listdir(lora_path):
            if self.check_lora_ext(lora_filename) and base_name in lora_filename:
                return os.path.join(lora_path, lora_filename)
        return None

    def handle_lora_load(self, file_name: str, lora_loader_args: dict = None):
        """Call ``self.lora_loader`` with ``lora_loader_args`` and report success."""
        self.lora_loader(**(lora_loader_args or {}))
        print(f"Successfully loaded LoRA from: {file_name}")

    def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,):
        """Best-effort load of a LoRA file for ``model`` from ``lora_path``.

        Errors are printed rather than raised, deliberately, so a missing or
        broken LoRA file does not abort the caller.
        """
        try:
            lora_file = self.get_lora_file_path(lora_path, model)

            if lora_file is not None:
                # Work on a copy so the caller's dict is not mutated.
                loader_args = dict(lora_loader_args or {})
                loader_args["lora_path"] = lora_file
                self.handle_lora_load(lora_file, loader_args)
            else:
                print(f"Could not load LoRAs for {model.__class__.__name__}. Injecting new ones instead...")
        except Exception as e:
            print(f"An error occurred while loading a LoRA file: {e}")

    def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale):
        """Build the keyword arguments passed to the LoRA injector/loader.

        ``use_lora`` and ``lora_bias`` are accepted for interface compatibility
        but not used here.
        """
        # Copy list values as well: a shallow `lora_args.copy()` would hand the
        # module-level default lists to the callee, letting it mutate them.
        return_dict = {k: list(v) if isinstance(v, list) else v for k, v in lora_args.items()}

        if self.is_cloneofsimo_lora():
            return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
            return_dict.update({
                "model": model,
                "loras": self.get_lora_file_path(lora_path, model),
                "target_replace_module": replace_modules,
                "r": r,
                "scale": scale,
                "dropout_p": dropout,
            })

        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias='none',
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
        """Inject trainable LoRA layers into ``model`` (cloneofsimo only).

        ``bias``, ``dropout`` and ``r`` are not read here; the injector receives
        its configuration through ``lora_loader_args``.

        Returns:
            ``(params, negation, is_injection_hybrid)``; ``params`` and
            ``negation`` are what the injector returned (``None`` for other
            versions), and ``is_injection_hybrid`` is True for cloneofsimo.
        """
        params = None
        negation = None
        is_injection_hybrid = False

        if self.is_cloneofsimo_lora():
            is_injection_hybrid = True
            params, negation = self.lora_injector(**lora_loader_args)  # inject_trainable_lora_extended

            # Report once if at least one up/down LoRA pair is now present.
            for _up, _down in extract_lora_ups_down(model, target_replace_module=replace_modules):
                if _up is not None and _down is not None:
                    print(f"Lora successfully injected into {model.__class__.__name__}.")
                    break

        return params, negation, is_injection_hybrid

    def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0):
        """Inject (and, for non-hybrid versions, load) LoRA layers into ``model``.

        Returns:
            ``(params, negation)`` where ``params`` is what the injector returned,
            or ``model`` itself when nothing was injected.
        """
        params = None
        negation = None

        lora_loader_args = self.get_lora_func_args(
            lora_path,
            use_lora,
            model,
            replace_modules,
            r,
            dropout,
            self.lora_bias,
            scale
        )

        if use_lora:
            params, negation, is_injection_hybrid = self.do_lora_injection(
                model,
                replace_modules,
                bias=self.lora_bias,
                lora_loader_args=lora_loader_args,
                dropout=dropout,
                r=r
            )

            # Hybrid injectors already received the weight path via "loras".
            if not is_injection_hybrid:
                self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args)

        params = model if params is None else params
        return params, negation

    def save_cloneofsimo_lora(self, model, save_path, step, flag):
        """Save UNet and text-encoder LoRA weights as ``{save_path}/{step}_{name}.pt``.

        Each part is saved only when its ``use_*_lora`` flag is set and its
        replace-module list is not ``None``.
        """

        def save_lora(model, name, condition, replace_modules, step, save_path, flag=None):
            if condition and replace_modules is not None:
                save_path = f"{save_path}/{step}_{name}.pt"
                save_lora_weight(model, save_path, replace_modules, flag)

        save_lora(
            model.unet,
            FILE_BASENAMES[0],
            self.use_unet_lora,
            self.unet_replace_modules,
            step,
            save_path,
            flag
        )
        save_lora(
            model.text_encoder,
            FILE_BASENAMES[1],
            self.use_text_lora,
            self.text_encoder_replace_modules,
            step,
            save_path,
            flag
        )

        # train_patch_pipe(model, self.use_unet_lora, self.use_text_lora)

    def save_lora_weights(self, model, save_path: str = '', step: str = '', flag=None):
        """Save LoRA weights of ``model`` under ``{save_path}/lora`` (created if missing).

        ``model`` is expected to expose ``.unet`` and ``.text_encoder``.
        """
        # Import the stdlib module directly instead of relying on it being
        # reachable through `logging`'s namespace.
        import warnings

        save_path = f"{save_path}/lora"
        os.makedirs(save_path, exist_ok=True)

        if self.is_cloneofsimo_lora():
            if any([self.save_for_webui, self.only_for_webui]):
                warnings.warn(
                    """
                    You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implemention.
                    Only 'stable_lora' is supported for saving to a compatible webui file.
                    """
                )
            self.save_cloneofsimo_lora(model, save_path, step, flag)
def inject_spatial_loras(unet, use_unet_lora, lora_unet_dropout, lora_path, lora_rank, spatial_lora_num):
    """Create ``spatial_lora_num`` LoRA handlers and add each one's LoRA to ``unet``.

    Every handler targets the ``Transformer2DModel`` modules of ``unet`` and
    looks for existing weights under ``lora_path + '/spatial/lora/'``.

    Returns:
        ``(handlers, params_list, negation)``: the created handlers, the params
        returned by each ``add_lora_to_model`` call, and the negation from the
        last call (``None`` when ``spatial_lora_num`` is 0).
    """
    lora_managers_spatial = []
    unet_lora_params_spatial_list = []
    # Initialised so spatial_lora_num == 0 returns None instead of raising
    # UnboundLocalError at the return statement.
    unet_negation_spatial = None

    for _ in range(spatial_lora_num):
        lora_manager_spatial = LoraHandler(
            use_unet_lora=use_unet_lora,
            unet_replace_modules=["Transformer2DModel"]
        )
        lora_managers_spatial.append(lora_manager_spatial)
        unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
            use_unet_lora,
            unet,
            lora_manager_spatial.unet_replace_modules,
            lora_unet_dropout,
            lora_path + '/spatial/lora/',
            r=lora_rank
        )
        unet_lora_params_spatial_list.append(unet_lora_params_spatial)

    return lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_spatial