"""
Usage:
Regular SVD:
python svd_low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
--new_rank=4 --new_lora_path="How2Draw-V2_000002800_svd.safetensors"
Randomized SVD:
python svd_low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
--new_rank=4 --niter=5 --new_lora_path="How2Draw-V2_000002800_svd.safetensors"
"""
from typing import Optional

import fire
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download


def randomized_svd(matrix, rank, niter=5):
"""
Performs a randomized SVD on the given matrix.
Args:
matrix (torch.Tensor): The input matrix.
rank (int): The target rank.
niter (int): Number of iterations for power method.
Returns:
U (torch.Tensor), S (torch.Tensor), Vh (torch.Tensor)
"""
    # Step 1: Generate a random Gaussian projection matrix
    omega = torch.randn(matrix.size(1), rank, device=matrix.device, dtype=matrix.dtype)
# Step 2: Form Y = A * Omega
Y = matrix @ omega
# Step 3: Orthonormalize Y using QR decomposition
Q, _ = torch.linalg.qr(Y, mode="reduced")
# Power iteration (optional, improves approximation)
for _ in range(niter):
Z = matrix.T @ Q
Q, _ = torch.linalg.qr(matrix @ Z, mode="reduced")
# Step 4: Compute B = Q^T * A
B = Q.T @ matrix
# Step 5: Compute SVD of the small matrix B
Ub, S, Vh = torch.linalg.svd(B, full_matrices=False)
# Step 6: Compute U = Q * Ub
U = Q @ Ub
return U[:, :rank], S[:rank], Vh[:rank, :]
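

# A quick sanity check (a sketch, not executed by the script): an exactly
# rank-4 matrix should be recovered almost perfectly by a rank-4 randomized SVD.
#   A = torch.randn(512, 4) @ torch.randn(4, 256)  # rank-4 by construction
#   U, S, Vh = randomized_svd(A, rank=4, niter=5)
#   print(torch.dist(U @ torch.diag(S) @ Vh, A))   # ~0 up to float error
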
def reduce_lora_rank(lora_A, lora_B, niter, new_rank=4):
"""
Reduces the rank of LoRA matrices lora_A and lora_B with SVD, supporting truncated SVD, too.
Args:
lora_A (torch.Tensor): Original lora_A matrix of shape [original_rank, in_features].
lora_B (torch.Tensor): Original lora_B matrix of shape [out_features, original_rank].
niter (int): Number of power iterations for randomized SVD.
new_rank (int): Desired lower rank.
Returns:
lora_A_new (torch.Tensor): Reduced lora_A matrix of shape [new_rank, in_features].
lora_B_new (torch.Tensor): Reduced lora_B matrix of shape [out_features, new_rank].
"""
    # Compute the low-rank update matrix in float32 for numerical stability,
    # falling back to CPU when CUDA is unavailable
    dtype = lora_A.dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lora_A = lora_A.to(device, torch.float32)
    lora_B = lora_B.to(device, torch.float32)
delta_W = lora_B @ lora_A
    # Full SVD of the update matrix when no power-iteration count is given,
    # otherwise a randomized SVD
    if niter is None:
        U, S, Vh = torch.linalg.svd(delta_W, full_matrices=False)
    else:
        U, S, Vh = randomized_svd(delta_W, rank=new_rank, niter=niter)
# Keep only the top 'new_rank' singular values and vectors
U_new = U[:, :new_rank]
S_new = S[:new_rank]
Vh_new = Vh[:new_rank, :]
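    # delta_W is now approximated by U_new @ diag(S_new) @ Vh_new. Splitting
    # diag(S_new) symmetrically into two sqrt factors lets one be folded into
    # each LoRA matrix, so lora_B_new @ lora_A_new reproduces the truncated
    # SVD while keeping the two matrices on a similar scale.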
# Compute the square roots of the singular values
S_sqrt = torch.sqrt(S_new)
# Compute the new lora_B and lora_A matrices
lora_B_new = U_new * S_sqrt.unsqueeze(0) # Shape: [out_features, new_rank]
lora_A_new = S_sqrt.unsqueeze(1) * Vh_new # Shape: [new_rank, in_features]
return lora_A_new.to(dtype), lora_B_new.to(dtype)
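

# Usage sketch (hypothetical shapes): compressing a rank-16 adapter to rank 4.
#   lora_A = torch.randn(16, 3072)  # [original_rank, in_features]
#   lora_B = torch.randn(3072, 16)  # [out_features, original_rank]
#   A4, B4 = reduce_lora_rank(lora_A, lora_B, niter=None, new_rank=4)
#   assert A4.shape == (4, 3072) and B4.shape == (3072, 4)
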
def reduce_lora_rank_state_dict(state_dict, niter, new_rank=4):
"""
Reduces the rank of all LoRA matrices in the given state dict.
Args:
state_dict (dict): The state dict containing LoRA matrices.
niter (int): Number of power iterations for ranodmized SVD.
new_rank (int): Desired lower rank.
Returns:
new_state_dict (dict): State dict with reduced-rank LoRA matrices.
"""
new_state_dict = state_dict.copy()
keys = list(state_dict.keys())
for key in keys:
if "lora_A.weight" in key:
# Find the corresponding lora_B
lora_A_key = key
lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
if lora_B_key in state_dict:
lora_A = state_dict[lora_A_key]
lora_B = state_dict[lora_B_key]
# Apply the rank reduction
lora_A_new, lora_B_new = reduce_lora_rank(lora_A, lora_B, niter=niter, new_rank=new_rank)
# Update the state dict
new_state_dict[lora_A_key] = lora_A_new
new_state_dict[lora_B_key] = lora_B_new
print(f"Reduced rank of {lora_A_key} and {lora_B_key} to {new_rank}")
return new_state_dict
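

# Note: the pairing above assumes PEFT-style key names, e.g. a (hypothetical)
# key "transformer.transformer_blocks.0.attn.to_q.lora_A.weight" matched with
# its "...lora_B.weight" counterpart. Entries without such a partner (for
# instance alpha scalars) are copied through unchanged.
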
def compare_approximation_error(orig_state_dict, new_state_dict):
for key in orig_state_dict:
if "lora_A.weight" in key:
lora_A_key = key
lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
lora_A_old = orig_state_dict[lora_A_key]
lora_B_old = orig_state_dict[lora_B_key]
lora_A_new = new_state_dict[lora_A_key]
lora_B_new = new_state_dict[lora_B_key]
            # Original delta_W, computed in float32 on the same device as the
            # reduced matrices so the comparison is fair
            delta_W_old = (lora_B_old.float() @ lora_A_old.float()).to(lora_A_new.device)
            # Approximated delta_W
            delta_W_new = lora_B_new.float() @ lora_A_new.float()
            # Compute the relative approximation error in the Frobenius norm
            error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
print(f"Relative error for {lora_A_key}: {error.item():.6f}")
def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    niter: Optional[int] = None,
    check_error: bool = False,
    new_lora_path: Optional[str] = None,
):
if new_lora_path is None:
raise ValueError("Please provide a path to serialize the converted state dict.")
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
original_state_dict = safetensors.torch.load_file(ckpt_path)
new_state_dict = reduce_lora_rank_state_dict(original_state_dict, niter=niter, new_rank=new_rank)
if check_error:
compare_approximation_error(original_state_dict, new_state_dict)
new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
safetensors.torch.save_file(new_state_dict, new_lora_path)
if __name__ == "__main__":
fire.Fire(main)