File size: 438 Bytes
e83bfa8
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

def strip_optimizer(checkpoint_dict, key="optimizer"):
    """Return a copy of *checkpoint_dict* with the *key* entry removed.

    The input mapping is left unmodified. The copy is shallow, so values are
    shared with the input rather than duplicated. Insertion order of the
    remaining entries is preserved.

    Args:
        checkpoint_dict: Mapping loaded from a checkpoint file.
        key: Entry to drop. Defaults to ``"optimizer"``.

    Returns:
        A new ``dict`` containing every entry except *key*.
    """
    if key in checkpoint_dict:
        # With the default key this prints "remove optimizer", as before.
        print(f"remove {key}")
    return {k: v for k, v in checkpoint_dict.items() if k != key}


if __name__ == '__main__':
    import sys

    # Usage: script.py [model_path] [output_path]
    # With no arguments, the original hard-coded paths are used.
    model_path = sys.argv[1] if len(sys.argv) > 1 else "saved_model/11/model.pth"
    output_path = sys.argv[2] if len(sys.argv) > 2 else "saved_model/11/model1.pth"

    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): recent torch releases default torch.load to weights_only=True;
    # if this checkpoint stores arbitrary Python objects, loading may fail there.
    # Confirm against the torch version in use.
    checkpoint_dict = torch.load(model_path, map_location='cpu')
    checkpoint_dict_new = strip_optimizer(checkpoint_dict)
    torch.save(checkpoint_dict_new, output_path)