|
import torch |
|
from diffusers.loaders.lora import LoraLoaderMixin |
|
from typing import Dict, Union |
|
import numpy as np |
|
import imageio |
|
|
|
def load_lora_weights(
    unet,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    adapter_name=None,
    *,
    low_cpu_mem_usage: bool = True,
    **kwargs,
):
    """Load LoRA weights from a checkpoint and inject them into ``unet``.

    The checkpoint is resolved with ``LoraLoaderMixin.lora_state_dict``. A
    leading ``"base_model.model."`` is stripped from every key, and the keys
    are validated. The weights are then loaded with
    ``LoraLoaderMixin.load_lora_into_unet``.

    Args:
        unet: The UNet model that receives the LoRA layers.
        pretrained_model_name_or_path_or_dict: Either a model id / path that
            ``lora_state_dict`` understands, or an already-loaded state dict.
            A dict argument is shallow-copied before use, so the caller's
            mapping is not handed to diffusers directly.
        adapter_name: Optional adapter name forwarded to
            ``load_lora_into_unet``.
        low_cpu_mem_usage: Forwarded to ``load_lora_into_unet``. Defaults to
            ``True``, the value this function previously hard-coded.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``lora_state_dict`` (presumably ``weight_name``, ``cache_dir``,
            etc. -- TODO confirm against the installed diffusers version).

    Raises:
        ValueError: If the resolved state dict is empty, or if any key
            contains neither ``"lora"`` nor ``"dora_scale"``.
    """
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        # Shallow copy so downstream processing cannot mutate the caller's dict.
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        pretrained_model_name_or_path_or_dict, **kwargs
    )

    # Keys saved from a PEFT-wrapped model presumably carry a leading
    # "base_model.model."; strip it only when it is an actual prefix, rather
    # than deleting that substring wherever it occurs inside a key.
    prefix = "base_model.model."
    state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): param
        for name, param in state_dict.items()
    }

    # all() over an empty dict is True, so reject an empty checkpoint
    # explicitly instead of silently loading nothing.
    if not state_dict:
        raise ValueError("Invalid LoRA checkpoint. The state dict is empty.")

    invalid_keys = [
        key for key in state_dict if "lora" not in key and "dora_scale" not in key
    ]
    if invalid_keys:
        # Report only a few offending keys to keep the message readable.
        raise ValueError(
            "Invalid LoRA checkpoint. Keys without 'lora' or 'dora_scale': "
            f"{invalid_keys[:5]}"
        )

    LoraLoaderMixin.load_lora_into_unet(
        state_dict,
        network_alphas=network_alphas,
        unet=unet,
        low_cpu_mem_usage=low_cpu_mem_usage,
        adapter_name=adapter_name,
    )
|
|
|
def save_video(frames, save_path, fps, quality=9):
    """Write a sequence of frames to ``save_path`` using an imageio writer.

    Args:
        frames: Iterable of frames. Each frame is converted with
            ``np.asarray`` before being appended, so array-likes such as PIL
            images are accepted. Assumes each frame is an image array the
            chosen imageio backend can encode (e.g. HxWx3 uint8) -- TODO
            confirm against callers.
        save_path: Destination path passed to ``imageio.get_writer``.
        fps: Frame rate passed to ``imageio.get_writer``.
        quality: Passed to ``imageio.get_writer``. For the ffmpeg backend this
            is presumably its 0-10 quality scale -- verify if another backend
            is used.
    """
    # The context manager guarantees the writer is closed (and the output
    # finalized) even if converting or appending a frame raises; previously
    # an exception mid-loop skipped writer.close() and leaked the writer.
    with imageio.get_writer(save_path, fps=fps, quality=quality) as writer:
        for frame in frames:
            # asarray avoids copying frames that are already ndarrays.
            writer.append_data(np.asarray(frame))