"""
Usage:
python upsample_lora_rank.py --repo_id="cocktailpeanut/optimus" \
--filename="optimus.safetensors" \
--new_lora_path="optimus_16.safetensors" \
--new_rank=16
"""
import fire
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download


def orthogonal_extension(matrix, target_rows):
    """
    Returns an orthonormal basis of the row space of `matrix`, extended to
    `target_rows` rows by appending additional mutually orthogonal unit rows.

    Args:
        matrix (torch.Tensor): Original matrix of shape [original_rows, columns].
        target_rows (int): Desired number of rows.

    Returns:
        extended_matrix (torch.Tensor): Matrix of shape [target_rows, columns]
            with orthonormal rows.
    """
    original_rows, cols = matrix.shape
    assert target_rows >= original_rows, "Target rows must be greater than or equal to original rows."

    # QR-decompose the transpose so that Q.T is an orthonormal basis of the row space.
    # The R factor is discarded, so the result spans the same rows as `matrix` but is
    # not numerically equal to it.
    Q, R = torch.linalg.qr(matrix.T, mode="reduced")  # Q has shape [columns, original_rows].
    Q = Q.T  # Back to [original_rows, columns].

    # Append extra rows, each orthogonalized (Gram-Schmidt) against every row
    # collected so far and normalized to unit length.
    if target_rows > original_rows:
        additional_rows = target_rows - original_rows
        random_matrix = torch.randn(additional_rows, cols, dtype=matrix.dtype, device=matrix.device)
        for i in range(additional_rows):
            v = random_matrix[i]
            v = v - Q.T @ (Q @ v)  # Remove components lying in the span of Q's rows.
            v = v / v.norm()
            Q = torch.cat([Q, v.unsqueeze(0)], dim=0)

    return Q
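
# Quick sanity check (illustrative only): the extended matrix has orthonormal rows,
# so ext @ ext.T should be close to the identity.
#
#   A = torch.randn(4, 64)
#   ext = orthogonal_extension(A, 16)  # shape: [16, 64]
#   torch.allclose(ext @ ext.T, torch.eye(16), atol=1e-5)  # True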

def increase_lora_rank_orthogonal(state_dict, target_rank=16):
    """
    Increases the rank of all LoRA matrices in the given state dict using
    orthogonal extension.

    Args:
        state_dict (dict): The state dict containing LoRA matrices.
        target_rank (int): Desired higher rank.

    Returns:
        new_state_dict (dict): State dict with increased-rank LoRA matrices.
    """
    # Fall back to CPU so the script also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    new_state_dict = state_dict.copy()
    for key in state_dict:
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            if lora_B_key in state_dict:
                lora_A = state_dict[lora_A_key]
                dtype = lora_A.dtype
                # Compute in float32 for the numerical stability of the QR decomposition.
                lora_A = lora_A.to(device, torch.float32)
                lora_B = state_dict[lora_B_key].to(device, torch.float32)
                original_rank = lora_A.shape[0]

                # Extend lora_A ([rank, in_features]) along its rows, and lora_B
                # ([out_features, rank]) along its columns via a transpose.
                lora_A_new = orthogonal_extension(lora_A, target_rank).to(dtype)
                lora_B_new = orthogonal_extension(lora_B.T, target_rank).T.to(dtype)

                new_state_dict[lora_A_key] = lora_A_new
                new_state_dict[lora_B_key] = lora_B_new
                print(
                    f"Increased rank of {lora_A_key} and {lora_B_key} from {original_rank} "
                    f"to {target_rank} using orthogonal extension"
                )
    return new_state_dict
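
# Illustrative usage on a toy state dict (hypothetical key names):
#
#   sd = {
#       "layer.lora_A.weight": torch.randn(4, 64),   # rank 4
#       "layer.lora_B.weight": torch.randn(128, 4),
#   }
#   new_sd = increase_lora_rank_orthogonal(sd, target_rank=16)
#   new_sd["layer.lora_A.weight"].shape  # torch.Size([16, 64])
#   new_sd["layer.lora_B.weight"].shape  # torch.Size([128, 16])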

def compare_approximation_error(orig_state_dict, new_state_dict):
    """Prints the relative Frobenius-norm error between the old and new delta_W per layer."""
    for key in orig_state_dict:
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            lora_A_old = orig_state_dict[lora_A_key]
            lora_B_old = orig_state_dict[lora_B_key]
            lora_A_new = new_state_dict[lora_A_key]
            lora_B_new = new_state_dict[lora_B_key]

            # Compare the weight deltas on the device of the upsampled tensors.
            delta_W_old = (lora_B_old @ lora_A_old).to(lora_A_new.device)
            delta_W_new = lora_B_new @ lora_A_new

            # Relative error: ||delta_W_old - delta_W_new||_F / ||delta_W_old||_F.
            error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
            print(f"Relative error for {lora_A_key}: {error.item():.6f}")

def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    check_error: bool = False,
    new_lora_path: str = None,
):
    if new_lora_path is None:
        raise ValueError("Please provide a path to serialize the converted state dict.")

    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    new_state_dict = increase_lora_rank_orthogonal(original_state_dict, target_rank=new_rank)
    if check_error:
        compare_approximation_error(original_state_dict, new_state_dict)

    # safetensors expects contiguous CPU tensors for serialization.
    new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
    safetensors.torch.save_file(new_state_dict, new_lora_path)


if __name__ == "__main__":
    fire.Fire(main)