"""Convert the original MedSAM ViT-B checkpoint to the Hugging Face format."""
from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys
from segment_anything import sam_model_registry  # pip install git+https://github.com/facebookresearch/segment-anything.git

# Load the original MedSAM ViT-B weights.
# Checkpoint download: https://drive.google.com/drive/folders/1ETWmi4AiniJeWOt6HAsYgTjYv_fkgzoN?usp=drive_link
# NOTE(review): sam_model_registry uses torch.load under the hood — only run
# this on a checkpoint from a trusted source.
checkpoint = 'medsam_vit_b.pth'
original_model = sam_model_registry['vit_b'](checkpoint)

# Remap the original state-dict keys to the transformers naming scheme,
# load them into a freshly configured SamModel, and save it to disk.
converted_state_dict = replace_keys(original_model.state_dict())
hf_model = SamModel(config=SamConfig())
hf_model.load_state_dict(converted_state_dict)
hf_model.save_pretrained('./')

# Build and save the matching processor. Mean/std normalization is disabled
# (identity mean/std kept for clarity) — presumably MedSAM only expects pixel
# values rescaled to [0, 1]; confirm against the MedSAM preprocessing code.
# resample=3 corresponds to PIL bicubic interpolation.
hf_processor = SamProcessor(
    image_processor=SamImageProcessor(
        do_normalize=False,
        image_mean=[0, 0, 0],
        image_std=[1, 1, 1],
        resample=3,
    )
)
hf_processor.save_pretrained('./')