Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
def average_checkpoints(checkpoint_paths):
    """Average the ``state_dict`` tensors of several checkpoints element-wise.

    The last path in ``checkpoint_paths`` is loaded first and serves as the
    template for the result: every top-level entry other than ``'state_dict'``
    is returned exactly as stored in that checkpoint.

    Args:
        checkpoint_paths: Non-empty sequence of paths readable by
            ``torch.load``. Each file must hold a mapping with a
            ``'state_dict'`` entry mapping names to tensors, and all
            checkpoints must share the same set of names.

    Returns:
        The last checkpoint's mapping with ``'state_dict'`` replaced by a
        plain ``dict`` of averaged tensors. Every averaged tensor keeps the
        dtype it had in the last checkpoint.

    Raises:
        ValueError: If ``checkpoint_paths`` is empty.
        KeyError: If a checkpoint's ``state_dict`` keys differ from those of
            the last checkpoint.
    """
    if not checkpoint_paths:
        raise ValueError("checkpoint_paths must contain at least one checkpoint path")

    cpu = torch.device('cpu')
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version and its weights_only default; only pass trusted files.
    averaged_ckpt = torch.load(checkpoint_paths[-1], map_location=cpu)

    original_dtypes = {}
    param_sum_dict = {}
    for key, value in averaged_ckpt['state_dict'].items():
        original_dtypes[key] = value.dtype
        if value.is_floating_point() or value.is_complex():
            # Sum half-precision tensors in at least float32 so the running
            # total neither overflows nor loses precision.
            accum_dtype = torch.promote_types(value.dtype, torch.float32)
        else:
            # Integer/bool buffers (e.g. step counters) are summed in int64.
            accum_dtype = torch.int64
        param_sum_dict[key] = value.detach().to(dtype=accum_dtype, copy=True)

    num_checkpoints = len(checkpoint_paths)
    for ckpt_path in checkpoint_paths[:-1]:
        state = torch.load(ckpt_path, map_location=cpu)['state_dict']
        if state.keys() != param_sum_dict.keys():
            # A missing key would otherwise be silently under-counted while
            # still being divided by the full number of checkpoints.
            mismatched = sorted(set(state) ^ set(param_sum_dict), key=str)
            raise KeyError(
                f"state_dict keys of {ckpt_path!r} differ from those of "
                f"{checkpoint_paths[-1]!r}: {mismatched}"
            )
        for key, value in state.items():
            param_sum_dict[key] += value

    averaged_state = {}
    for key, total in param_sum_dict.items():
        if total.is_floating_point() or total.is_complex():
            mean = total / num_checkpoints
        else:
            # True division would turn integer buffers into floats; use floor
            # division so they keep an integer value and dtype.
            mean = torch.div(total, num_checkpoints, rounding_mode='floor')
        averaged_state[key] = mean.to(original_dtypes[key])

    averaged_ckpt['state_dict'] = averaged_state
    return averaged_ckpt
def parse_arguments(argv=None):
    """Parse the command-line options of the checkpoint-averaging script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the arguments from ``sys.argv``.

    Returns:
        argparse.Namespace with two attributes: ``checkpoint_paths`` (a list
        of one or more path strings) and ``output_path`` (a string).
    """
    parser = argparse.ArgumentParser(description="Averages the weights of multiple transformer model checkpoints.")
    parser.add_argument('--checkpoint_paths', nargs='+', required=True,
                        help='List of paths to the checkpoints to be averaged. Example: --checkpoint_paths path1 path2 path3')
    parser.add_argument('--output_path', type=str, required=True,
                        help='Path where the averaged checkpoint is written.')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, average the listed
    # checkpoints, and write the merged checkpoint to the requested path.
    args = parse_arguments()
    merged_checkpoint = average_checkpoints(args.checkpoint_paths)
    torch.save(merged_checkpoint, args.output_path)
    print(f"Averaged checkpoint saved to {args.output_path}")