flaviagiammarino commited on
Commit
60164d3
1 Parent(s): 7f4cded

Upload 2 files

Browse files
Files changed (2) hide show
  1. scripts/pt_example.png +0 -0
  2. scripts/pt_example.py +42 -0
scripts/pt_example.png ADDED
scripts/pt_example.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ from transformers import SamModel, SamProcessor
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ model = SamModel.from_pretrained("flaviagiammarino/medsam-vit-base")
11
+ processor = SamProcessor.from_pretrained("flaviagiammarino/medsam-vit-base")
12
+
13
+ img_url = "https://raw.githubusercontent.com/bowang-lab/MedSAM/main/assets/img_demo.png"
14
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
15
+ input_boxes = [95, 255, 190, 350]
16
+
17
+ inputs = processor(raw_image, input_boxes=[input_boxes], return_tensors="pt").to(device)
18
+ outputs = model(**inputs, multimask_output=False)
19
+ masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
20
+
21
+ def show_mask(mask, ax, random_color):
22
+ if random_color:
23
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
24
+ else:
25
+ color = np.array([251/255, 252/255, 30/255, 0.6])
26
+ h, w = mask.shape[-2:]
27
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
28
+ ax.imshow(mask_image)
29
+
30
+ def show_box(box, ax):
31
+ x0, y0 = box[0], box[1]
32
+ w, h = box[2] - box[0], box[3] - box[1]
33
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2))
34
+
35
+ plt.imshow(np.array(raw_image))
36
+ ax = plt.gca()
37
+ for mask in masks:
38
+ show_mask(mask, ax=ax, random_color=False)
39
+ show_box(input_boxes, ax)
40
+ plt.axis("off")
41
+ plt.tight_layout()
42
+ plt.show()