flaviagiammarino commited on
Commit
5f9e8f9
1 Parent(s): e734391

Create scripts/pt_model.py

Browse files
Files changed (1) hide show
  1. scripts/pt_model.py +29 -0
scripts/pt_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
2
+ from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys
3
+
4
+ from segment_anything import sam_model_registry # pip install git+https://github.com/facebookresearch/segment-anything.git
5
+
6
+ # load the MedSAM ViT-B model
7
+ checkpoint = 'medsam_vit_b.pth' # https://drive.google.com/file/d/1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_/view?usp=drive_link
8
+ pt_model = sam_model_registry['vit_b'](checkpoint)
9
+ pt_state_dict = pt_model.state_dict()
10
+
11
+ # tweak the model's weights to transformers design
12
+ hf_state_dict = replace_keys(pt_state_dict)
13
+
14
+ # save the model
15
+ hf_model = SamModel(config=SamConfig())
16
+ hf_model.load_state_dict(hf_state_dict)
17
+ hf_model.save_pretrained('./')
18
+
19
+ # update the processor, inputs are min-max scaled instead of normalized
20
+ hf_processor = SamProcessor(
21
+ image_processor=SamImageProcessor(
22
+ do_normalize=False,
23
+ image_mean=[0, 0, 0],
24
+ image_std=[1, 1, 1],
25
+ )
26
+ )
27
+
28
+ # save the processor
29
+ hf_processor.save_pretrained('./')