flaviagiammarino
commited on
Commit
•
5f9e8f9
1
Parent(s):
e734391
Create scripts/pt_model.py
Browse files- scripts/pt_model.py +29 -0
scripts/pt_model.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
|
2 |
+
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys
|
3 |
+
|
4 |
+
from segment_anything import sam_model_registry # pip install git+https://github.com/facebookresearch/segment-anything.git
|
5 |
+
|
6 |
+
# load the MedSAM ViT-B model
|
7 |
+
checkpoint = 'medsam_vit_b.pth' # https://drive.google.com/file/d/1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_/view?usp=drive_link
|
8 |
+
pt_model = sam_model_registry['vit_b'](checkpoint)
|
9 |
+
pt_state_dict = pt_model.state_dict()
|
10 |
+
|
11 |
+
# tweak the model's weights to transformers design
|
12 |
+
hf_state_dict = replace_keys(pt_state_dict)
|
13 |
+
|
14 |
+
# save the model
|
15 |
+
hf_model = SamModel(config=SamConfig())
|
16 |
+
hf_model.load_state_dict(hf_state_dict)
|
17 |
+
hf_model.save_pretrained('./')
|
18 |
+
|
19 |
+
# update the processor, inputs are min-max scaled instead of normalized
|
20 |
+
hf_processor = SamProcessor(
|
21 |
+
image_processor=SamImageProcessor(
|
22 |
+
do_normalize=False,
|
23 |
+
image_mean=[0, 0, 0],
|
24 |
+
image_std=[1, 1, 1],
|
25 |
+
)
|
26 |
+
)
|
27 |
+
|
28 |
+
# save the processor
|
29 |
+
hf_processor.save_pretrained('./')
|