"""
Upsamples the rank of a LoRA state dict by orthogonally extending its A and B matrices.

Usage:

python upsample_lora_rank.py --repo_id="cocktailpeanut/optimus" \
    --filename="optimus.safetensors" \
    --new_lora_path="optimus_16.safetensors" \
    --new_rank=16

Pass --check_error to also print the relative approximation error introduced per layer.
"""

import torch
from huggingface_hub import hf_hub_download
import safetensors.torch
import fire


def orthogonal_extension(matrix, target_rows):
    """
    Extends the given matrix to have `target_rows` rows by adding orthogonal rows.

    Args:
        matrix (torch.Tensor): Original matrix of shape [original_rows, columns].
        target_rows (int): Desired number of rows.

    Returns:
        extended_matrix (torch.Tensor): Matrix of shape [target_rows, columns].
    """
    original_rows, cols = matrix.shape
    assert target_rows >= original_rows, "Target rows must be greater than or equal to original rows."
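
    # A reduced QR decomposition of matrix.T yields Q whose columns form an orthonormal basis
    # for the row space of `matrix`; transposing gives those basis vectors as rows.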
    Q, R = torch.linalg.qr(matrix.T, mode="reduced")
    Q = Q.T
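
    # When more rows are requested, draw random vectors, remove their components along the
    # existing rows (a Gram-Schmidt step), normalize, and append each as a new orthonormal row.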
    if target_rows > original_rows:
        additional_rows = target_rows - original_rows
        random_matrix = torch.randn(additional_rows, cols, dtype=matrix.dtype, device=matrix.device)

        for i in range(additional_rows):
            v = random_matrix[i]
            v = v - Q.T @ (Q @ v)
            v = v / v.norm()
            Q = torch.cat([Q, v.unsqueeze(0)], dim=0)
    extended_matrix = Q
    return extended_matrix


def increase_lora_rank_orthogonal(state_dict, target_rank=16):
    """
    Increases the rank of all LoRA matrices in the given state dict using orthogonal extension.

    Args:
        state_dict (dict): The state dict containing LoRA matrices.
        target_rank (int): Desired higher rank.

    Returns:
        new_state_dict (dict): State dict with increased-rank LoRA matrices.
    """
    new_state_dict = state_dict.copy()
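    # Walk over every lora_A.weight key, pair it with the matching lora_B.weight,
    # and upcast both to float32 on the GPU before extending them.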
    for key in state_dict.keys():
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            if lora_B_key in state_dict:
                lora_A = state_dict[lora_A_key]
                dtype = lora_A.dtype
                lora_A = lora_A.to("cuda", torch.float32)
                lora_B = state_dict[lora_B_key]
                lora_B = lora_B.to("cuda", torch.float32)

                original_rank = lora_A.shape[0]
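
                # lora_A is [rank, in_features], so its rows are extended directly; lora_B is
                # [out_features, rank], so it is transposed, extended, and transposed back so
                # the new rank dimension lines up.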
                lora_A_new = orthogonal_extension(lora_A, target_rank).to(dtype)
                lora_B_new = orthogonal_extension(lora_B.T, target_rank).T.to(dtype)

                new_state_dict[lora_A_key] = lora_A_new
                new_state_dict[lora_B_key] = lora_B_new

                print(
                    f"Increased rank of {lora_A_key} and {lora_B_key} from {original_rank} to {target_rank} using orthogonal extension"
                )

    return new_state_dict


def compare_approximation_error(orig_state_dict, new_state_dict):
    for key in orig_state_dict:
        if "lora_A.weight" in key:
            lora_A_key = key
            lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
            lora_A_old = orig_state_dict[lora_A_key]
            lora_B_old = orig_state_dict[lora_B_key]
            lora_A_new = new_state_dict[lora_A_key]
            lora_B_new = new_state_dict[lora_B_key]
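
            # Weight delta of the original low-rank pair, moved to the GPU to match the new tensors.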
            delta_W_old = (lora_B_old @ lora_A_old).to("cuda")
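
            # Weight delta of the rank-upsampled pair.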
            delta_W_new = lora_B_new @ lora_A_new
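
            # Relative Frobenius-norm error between the two deltas.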
            error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
            print(f"Relative error for {lora_A_key}: {error.item():.6f}")


def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    check_error: bool = False,
    new_lora_path: str = None,
):
    if new_lora_path is None:
        raise ValueError("Please provide a path to serialize the converted state dict.")
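
    # Download the original LoRA checkpoint from the Hub and load its state dict.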
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    new_state_dict = increase_lora_rank_orthogonal(original_state_dict, target_rank=new_rank)

    if check_error:
        compare_approximation_error(original_state_dict, new_state_dict)
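
    # Move everything back to CPU and make the tensors contiguous before serialization.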
    new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}

    safetensors.torch.save_file(new_state_dict, new_lora_path)


if __name__ == "__main__":
    fire.Fire(main)