Diffusers
sayakpaul committed
Commit d69d12f
1 Parent(s): 8bbc175

Upload upsample_lora_rank.py

Files changed (1)
  1. upsample_lora_rank.py +133 -0
upsample_lora_rank.py ADDED
@@ -0,0 +1,133 @@
+ """
+ Usage:
+
+ python upsample_lora_rank.py --repo_id="cocktailpeanut/optimus" \
+     --filename="optimus.safetensors" \
+     --new_lora_path="optimus_16.safetensors" \
+     --new_rank=16
+ """
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ import safetensors.torch
+ import fire
+
+ # Fall back to CPU so the script also runs on machines without a GPU.
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def orthogonal_extension(matrix, target_rows):
+     """
+     Extends the given matrix to have target_rows rows by adding orthogonal rows.
+
+     Args:
+         matrix (torch.Tensor): Original matrix of shape [original_rows, columns].
+         target_rows (int): Desired number of rows.
+
+     Returns:
+         extended_matrix (torch.Tensor): Matrix of shape [target_rows, columns].
+     """
+     original_rows, cols = matrix.shape
+     assert target_rows >= original_rows, "Target rows must be greater than or equal to original rows."
+
+     # QR decomposition of the transpose yields an orthonormal basis of the row space.
+     Q, _ = torch.linalg.qr(matrix.T, mode="reduced")  # Q: [columns, original_rows]
+     Q = Q.T  # Back to [original_rows, columns]; rows are orthonormal.
+
+     # Append rows orthogonal to the existing row space (Gram-Schmidt), keeping
+     # the original rows intact so the low-rank factor itself is preserved.
+     extended_matrix = matrix
+     if target_rows > original_rows:
+         additional_rows = target_rows - original_rows
+         random_matrix = torch.randn(additional_rows, cols, dtype=matrix.dtype, device=matrix.device)
+         for i in range(additional_rows):
+             v = random_matrix[i]
+             v = v - Q.T @ (Q @ v)  # Remove components along the current basis.
+             v = v / v.norm()
+             Q = torch.cat([Q, v.unsqueeze(0)], dim=0)
+             extended_matrix = torch.cat([extended_matrix, v.unsqueeze(0)], dim=0)
+     return extended_matrix
+
+
+ def increase_lora_rank_orthogonal(state_dict, target_rank=16):
+     """
+     Increases the rank of all LoRA matrices in the given state dict using orthogonal extension.
+
+     Args:
+         state_dict (dict): The state dict containing LoRA matrices.
+         target_rank (int): Desired higher rank.
+
+     Returns:
+         new_state_dict (dict): State dict with increased-rank LoRA matrices.
+     """
+     new_state_dict = state_dict.copy()
+     for key in state_dict.keys():
+         if "lora_A.weight" in key:
+             lora_A_key = key
+             lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
+             if lora_B_key in state_dict:
+                 lora_A = state_dict[lora_A_key]
+                 dtype = lora_A.dtype
+                 lora_A = lora_A.to(DEVICE, torch.float32)
+                 lora_B = state_dict[lora_B_key].to(DEVICE, torch.float32)
+
+                 original_rank = lora_A.shape[0]
+
+                 # Extend lora_A along its rows and lora_B along its columns.
+                 lora_A_new = orthogonal_extension(lora_A, target_rank).to(dtype)
+                 lora_B_new = orthogonal_extension(lora_B.T, target_rank).T.to(dtype)
+
+                 # Update the state dict.
+                 new_state_dict[lora_A_key] = lora_A_new
+                 new_state_dict[lora_B_key] = lora_B_new
+
+                 print(
+                     f"Increased rank of {lora_A_key} and {lora_B_key} from {original_rank} to {target_rank} using orthogonal extension"
+                 )
+
+     return new_state_dict
+
+
+ def compare_approximation_error(orig_state_dict, new_state_dict):
+     for key in orig_state_dict:
+         if "lora_A.weight" in key:
+             lora_A_key = key
+             lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
+             lora_A_old = orig_state_dict[lora_A_key]
+             lora_B_old = orig_state_dict[lora_B_key]
+             lora_A_new = new_state_dict[lora_A_key]
+             lora_B_new = new_state_dict[lora_B_key]
+
+             # Original delta_W.
+             delta_W_old = (lora_B_old @ lora_A_old).to(DEVICE)
+
+             # Upsampled delta_W.
+             delta_W_new = lora_B_new @ lora_A_new
+
+             # Relative approximation error in the Frobenius norm.
+             error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
+             print(f"Relative error for {lora_A_key}: {error.item():.6f}")
+
+
+ def main(
+     repo_id: str,
+     filename: str,
+     new_rank: int,
+     check_error: bool = False,
+     new_lora_path: str = None,
+ ):
+     if new_lora_path is None:
+         raise ValueError("Please provide a path to serialize the converted state dict.")
+
+     ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
+     original_state_dict = safetensors.torch.load_file(ckpt_path)
+     new_state_dict = increase_lora_rank_orthogonal(original_state_dict, target_rank=new_rank)
+
+     if check_error:
+         compare_approximation_error(original_state_dict, new_state_dict)
+
+     new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
+     safetensors.torch.save_file(new_state_dict, new_lora_path)
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)
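
As a quick sanity check, the sketch below exercises orthogonal_extension on a toy LoRA pair: it confirms that the appended rows of the upsampled A factor are orthogonal to the original row space, and reports how much the product delta_W = B @ A moves because of the random extra directions (this is exactly what the script's --check_error flag measures per layer). The shapes (rank 4 to 16, 64-dim features) are illustrative and not taken from the checkpoint above.

import torch
from upsample_lora_rank import orthogonal_extension

rank, target_rank, d_in, d_out = 4, 16, 64, 64
A = torch.randn(rank, d_in)
B = torch.randn(d_out, rank)

A_new = orthogonal_extension(A, target_rank)
B_new = orthogonal_extension(B.T, target_rank).T

# Appended rows should be orthogonal to the original row space of A.
cross = A @ A_new[rank:].T
print(f"max |<old rows, new rows>|: {cross.abs().max().item():.2e}")

# Relative change in delta_W introduced by the extra rank directions.
delta_old = B @ A
delta_new = B_new @ A_new
rel = torch.norm(delta_new - delta_old, p="fro") / torch.norm(delta_old, p="fro")
print(f"relative delta_W change: {rel.item():.6f}")

Once serialized, the upsampled file should load like any other LoRA in diffusers. A minimal sketch, assuming the checkpoint above is a Flux LoRA in diffusers format; the base model, prompt, and output file name are placeholders, not taken from the commit:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Load the rank-upsampled LoRA produced by the script above.
pipe.load_lora_weights(".", weight_name="optimus_16.safetensors")
image = pipe("a photo in the style of optimus", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("optimus_16.png")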