File size: 395 Bytes
45b4aa7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
import argparse
def prune_ckpt(ckpt_path, save_path):
    """Write only the ``"state_dict"`` entry of a checkpoint to a new file.

    The checkpoint at ``ckpt_path`` is loaded with its tensors mapped to the
    CPU, its ``"state_dict"`` entry is looked up, and that entry alone is
    saved to ``save_path``. Everything else stored in the checkpoint is
    dropped from the output file.

    Args:
        ckpt_path: Location of the checkpoint to read.
        save_path: Location the extracted state dict is written to.

    Raises:
        KeyError: Presumably when the loaded checkpoint is a mapping without
            a ``"state_dict"`` entry (the lookup is a plain subscript).
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's defaults -- only run on trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    torch.save(checkpoint["state_dict"], save_path)
if __name__ == '__main__':
    # Command-line entry point: read the input and output paths, then prune.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without them there is nothing to load or save,
    # so argparse reports a usage error instead of passing None downstream.
    parser.add_argument('--ckpt_path', type=str, required=True)
    parser.add_argument('--save_path', type=str, required=True)
    args = parser.parse_args()
    # The original script parsed the arguments but never acted on them;
    # actually perform the pruning with the parsed paths.
    prune_ckpt(args.ckpt_path, args.save_path)
|