"""
Usage:

Regular SVD:
python svd_low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
    --new_rank=4 --new_lora_path="How2Draw-V2_000002800_svd.safetensors"

Randomized SVD:
python svd_low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
    --new_rank=4 --niter=5 --new_lora_path="How2Draw-V2_000002800_svd.safetensors"
"""

import torch
from huggingface_hub import hf_hub_download
import safetensors.torch
import fire


def randomized_svd(matrix, rank, niter=5):
    """
    Performs a randomized SVD on the given matrix.

    Args:
        matrix (torch.Tensor): The input matrix.
        rank (int): The target rank.
        niter (int): Number of iterations for the power method.

    Returns:
        U (torch.Tensor), S (torch.Tensor), Vh (torch.Tensor)
    """
    # Random Gaussian test matrix used to sample the column space of `matrix`.
    omega = torch.randn(matrix.size(1), rank, device=matrix.device)

    # Sketch the column space: Y approximately spans the top-`rank` directions.
    Y = matrix @ omega

    # Orthonormalize the sketch.
    Q, _ = torch.linalg.qr(Y, mode="reduced")

    # Power iterations sharpen the subspace estimate when the singular values
    # of `matrix` decay slowly.
    for _ in range(niter):
        Z = matrix.T @ Q
        Q, _ = torch.linalg.qr(matrix @ Z, mode="reduced")

    # Project `matrix` onto the subspace and run an exact SVD on the small matrix.
    B = Q.T @ matrix
    Ub, S, Vh = torch.linalg.svd(B, full_matrices=False)

    # Map the left singular vectors back to the original space.
    U = Q @ Ub
    return U[:, :rank], S[:rank], Vh[:rank, :]
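

def _example_randomized_svd_check():
    """
    Minimal sanity check for `randomized_svd` (an illustrative helper added
    here, not part of the conversion pipeline): an exactly rank-8 matrix
    should be reconstructed almost perfectly by a rank-8 randomized SVD.
    """
    A = torch.randn(512, 8) @ torch.randn(8, 256)  # exactly rank 8
    U, S, Vh = randomized_svd(A, rank=8, niter=5)
    A_hat = U @ torch.diag(S) @ Vh
    rel_err = torch.linalg.norm(A - A_hat) / torch.linalg.norm(A)
    print(f"randomized SVD relative reconstruction error: {rel_err.item():.2e}")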


def reduce_lora_rank(lora_A, lora_B, niter, new_rank=4):
    """
    Reduces the rank of the LoRA matrices lora_A and lora_B via SVD,
    optionally using randomized (truncated) SVD.

    Args:
        lora_A (torch.Tensor): Original lora_A matrix of shape [original_rank, in_features].
        lora_B (torch.Tensor): Original lora_B matrix of shape [out_features, original_rank].
        niter (int): Number of power iterations for randomized SVD. If None, an exact SVD is used.
        new_rank (int): Desired lower rank.

    Returns:
        lora_A_new (torch.Tensor): Reduced lora_A matrix of shape [new_rank, in_features].
        lora_B_new (torch.Tensor): Reduced lora_B matrix of shape [out_features, new_rank].
    """
    dtype = lora_A.dtype
    # Compute in float32 (on GPU when available) for numerical stability.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lora_A = lora_A.to(device, torch.float32)
    lora_B = lora_B.to(device, torch.float32)
    # Reconstruct the full weight update this LoRA pair represents.
    delta_W = lora_B @ lora_A

    if niter is None:
        # Exact SVD; truncated to `new_rank` below.
        U, S, Vh = torch.linalg.svd(delta_W, full_matrices=False)
    else:
        # Randomized SVD already returns a rank-`new_rank` factorization.
        U, S, Vh = randomized_svd(delta_W, rank=new_rank, niter=niter)

    # Keep only the top `new_rank` singular triplets.
    U_new = U[:, :new_rank]
    S_new = S[:new_rank]
    Vh_new = Vh[:new_rank, :]

    # Split the singular values evenly between the two factors so that
    # lora_B_new @ lora_A_new == U_new @ diag(S_new) @ Vh_new.
    S_sqrt = torch.sqrt(S_new)
    lora_B_new = U_new * S_sqrt.unsqueeze(0)
    lora_A_new = S_sqrt.unsqueeze(1) * Vh_new

    return lora_A_new.to(dtype), lora_B_new.to(dtype)
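

def _example_reduce_lora_rank():
    """
    Illustrative usage of `reduce_lora_rank` on synthetic tensors (this helper
    and its shapes are made up for the example; real shapes come from the
    checkpoint). The reduced factors keep the LoRA layout at the new rank.
    """
    lora_A = torch.randn(16, 128)  # [original_rank, in_features]
    lora_B = torch.randn(256, 16)  # [out_features, original_rank]
    lora_A_new, lora_B_new = reduce_lora_rank(lora_A, lora_B, niter=None, new_rank=4)
    assert lora_A_new.shape == (4, 128)
    assert lora_B_new.shape == (256, 4)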


def reduce_lora_rank_state_dict(state_dict, niter, new_rank=4):
    """
    Reduces the rank of all LoRA matrices in the given state dict.

    Args:
        state_dict (dict): The state dict containing LoRA matrices.
        niter (int): Number of power iterations for randomized SVD.
        new_rank (int): Desired lower rank.

    Returns:
        new_state_dict (dict): State dict with reduced-rank LoRA matrices.
    """
    new_state_dict = state_dict.copy()
    keys = list(state_dict.keys())
    for key in keys:
        if "lora_A.weight" in key:
            # Pair each lora_A with its corresponding lora_B.
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            if lora_B_key in state_dict:
                lora_A = state_dict[lora_A_key]
                lora_B = state_dict[lora_B_key]

                lora_A_new, lora_B_new = reduce_lora_rank(lora_A, lora_B, niter=niter, new_rank=new_rank)

                new_state_dict[lora_A_key] = lora_A_new
                new_state_dict[lora_B_key] = lora_B_new

                print(f"Reduced rank of {lora_A_key} and {lora_B_key} to {new_rank}")

    return new_state_dict
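

# Note: the key matching above assumes PEFT/diffusers-style naming, where each
# adapter pair appears as "<module>.lora_A.weight" / "<module>.lora_B.weight",
# e.g. (hypothetical key, shown only for illustration):
#   "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight"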


def compare_approximation_error(orig_state_dict, new_state_dict):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for key in orig_state_dict:
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            lora_A_old = orig_state_dict[lora_A_key]
            lora_B_old = orig_state_dict[lora_B_key]
            lora_A_new = new_state_dict[lora_A_key]
            lora_B_new = new_state_dict[lora_B_key]

            # Reconstruct the original and reduced-rank weight updates in
            # float32 so the norms are computed at full precision.
            delta_W_old = (lora_B_old @ lora_A_old).to(device, torch.float32)
            delta_W_new = (lora_B_new @ lora_A_new).to(device, torch.float32)

            # Relative Frobenius-norm error of the low-rank approximation.
            error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
            print(f"Relative error for {lora_A_key}: {error.item():.6f}")


def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    niter: int = None,
    check_error: bool = False,
    new_lora_path: str = None,
):
    if new_lora_path is None:
        raise ValueError("Please provide a path to serialize the converted state dict.")

    # Download the original LoRA checkpoint and load its state dict.
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    new_state_dict = reduce_lora_rank_state_dict(original_state_dict, niter=niter, new_rank=new_rank)

    if check_error:
        compare_approximation_error(original_state_dict, new_state_dict)

    # Move tensors back to CPU and make them contiguous before serialization.
    new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
    safetensors.torch.save_file(new_state_dict, new_lora_path)


if __name__ == "__main__":
    fire.Fire(main)