Upload upsample_lora_rank.py
Browse files- upsample_lora_rank.py +133 -0
upsample_lora_rank.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
|
4 |
+
python upsample_lora_rank.py --repo_id="cocktailpeanut/optimus" \
|
5 |
+
--filename="optimus.safetensors" \
|
6 |
+
--new_lora_path="optimus_16.safetensors" \
|
7 |
+
--new_rank=16
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
import safetensors.torch
|
13 |
+
import fire
|
14 |
+
|
15 |
+
|
16 |
+
def orthogonal_extension(matrix, target_rows):
|
17 |
+
"""
|
18 |
+
Extends the given matrix to have target_rows rows by adding orthogonal rows.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
matrix (torch.Tensor): Original matrix of shape [original_rows, columns].
|
22 |
+
target_rows (int): Desired number of rows.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
extended_matrix (torch.Tensor): Matrix of shape [target_rows, columns].
|
26 |
+
"""
|
27 |
+
original_rows, cols = matrix.shape
|
28 |
+
assert target_rows >= original_rows, "Target rows must be greater than or equal to original rows."
|
29 |
+
|
30 |
+
# Perform QR decomposition
|
31 |
+
Q, R = torch.linalg.qr(matrix.T, mode="reduced") # Transpose to get [columns, original_rows]
|
32 |
+
Q = Q.T # Back to [original_rows, columns]
|
33 |
+
|
34 |
+
# Generate orthogonal vectors
|
35 |
+
if target_rows > original_rows:
|
36 |
+
additional_rows = target_rows - original_rows
|
37 |
+
random_matrix = torch.randn(additional_rows, cols, dtype=matrix.dtype, device=matrix.device)
|
38 |
+
# Orthogonalize against existing Q
|
39 |
+
for i in range(additional_rows):
|
40 |
+
v = random_matrix[i]
|
41 |
+
v = v - Q.T @ (Q @ v)
|
42 |
+
v = v / v.norm()
|
43 |
+
Q = torch.cat([Q, v.unsqueeze(0)], dim=0)
|
44 |
+
extended_matrix = Q
|
45 |
+
return extended_matrix
|
46 |
+
|
47 |
+
|
48 |
+
def increase_lora_rank_orthogonal(state_dict, target_rank=16):
|
49 |
+
"""
|
50 |
+
Increases the rank of all LoRA matrices in the given state dict using orthogonal extension.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
state_dict (dict): The state dict containing LoRA matrices.
|
54 |
+
target_rank (int): Desired higher rank.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
new_state_dict (dict): State dict with increased-rank LoRA matrices.
|
58 |
+
"""
|
59 |
+
new_state_dict = state_dict.copy()
|
60 |
+
for key in state_dict.keys():
|
61 |
+
if "lora_A.weight" in key:
|
62 |
+
lora_A_key = key
|
63 |
+
lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
|
64 |
+
if lora_B_key in state_dict:
|
65 |
+
lora_A = state_dict[lora_A_key]
|
66 |
+
dtype = lora_A.dtype
|
67 |
+
lora_A = lora_A.to("cuda", torch.float32)
|
68 |
+
lora_B = state_dict[lora_B_key]
|
69 |
+
lora_B = lora_B.to("cuda", torch.float32)
|
70 |
+
|
71 |
+
original_rank = lora_A.shape[0]
|
72 |
+
|
73 |
+
# Extend lora_A and lora_B
|
74 |
+
lora_A_new = orthogonal_extension(lora_A, target_rank).to(dtype)
|
75 |
+
lora_B_new = orthogonal_extension(lora_B.T, target_rank).T.to(dtype) # Transpose to match dimensions
|
76 |
+
|
77 |
+
# Update the state dict
|
78 |
+
new_state_dict[lora_A_key] = lora_A_new
|
79 |
+
new_state_dict[lora_B_key] = lora_B_new
|
80 |
+
|
81 |
+
print(
|
82 |
+
f"Increased rank of {lora_A_key} and {lora_B_key} from {original_rank} to {target_rank} using orthogonal extension"
|
83 |
+
)
|
84 |
+
|
85 |
+
return new_state_dict
|
86 |
+
|
87 |
+
|
88 |
+
def compare_approximation_error(orig_state_dict, new_state_dict):
|
89 |
+
for key in orig_state_dict:
|
90 |
+
if "lora_A.weight" in key:
|
91 |
+
lora_A_key = key
|
92 |
+
lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
|
93 |
+
lora_A_old = orig_state_dict[lora_A_key]
|
94 |
+
lora_B_old = orig_state_dict[lora_B_key]
|
95 |
+
lora_A_new = new_state_dict[lora_A_key]
|
96 |
+
lora_B_new = new_state_dict[lora_B_key]
|
97 |
+
|
98 |
+
# Original delta_W
|
99 |
+
delta_W_old = (lora_B_old @ lora_A_old).to("cuda")
|
100 |
+
|
101 |
+
# Approximated delta_W
|
102 |
+
delta_W_new = lora_B_new @ lora_A_new
|
103 |
+
|
104 |
+
# Compute the approximation error
|
105 |
+
error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
|
106 |
+
print(f"Relative error for {lora_A_key}: {error.item():.6f}")
|
107 |
+
|
108 |
+
|
109 |
+
def main(
|
110 |
+
repo_id: str,
|
111 |
+
filename: str,
|
112 |
+
new_rank: int,
|
113 |
+
check_error: bool = False,
|
114 |
+
new_lora_path: str = None,
|
115 |
+
):
|
116 |
+
# ckpt_path = hf_hub_download(repo_id="TheLastBen/The_Hound", filename="sandor_clegane_single_layer.safetensors")
|
117 |
+
if new_lora_path is None:
|
118 |
+
raise ValueError("Please provide a path to serialize the converted state dict.")
|
119 |
+
|
120 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
121 |
+
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
122 |
+
new_state_dict = increase_lora_rank_orthogonal(original_state_dict, target_rank=new_rank)
|
123 |
+
|
124 |
+
if check_error:
|
125 |
+
compare_approximation_error(original_state_dict, new_state_dict)
|
126 |
+
|
127 |
+
new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
|
128 |
+
# safetensors.torch.save_file(new_state_dict, "sandor_clegane_single_layer_32.safetensors")
|
129 |
+
safetensors.torch.save_file(new_state_dict, new_lora_path)
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == "__main__":
|
133 |
+
fire.Fire(main)
|