mart9992 committed
Commit 062a1ef
1 Parent(s): b93bdbf
Files changed (23)
  1. grounded_sam_demo.py +51 -159
  2. segment_anything/segment_anything.egg-info/PKG-INFO +15 -0
  3. segment_anything/segment_anything.egg-info/SOURCES.txt +26 -0
  4. segment_anything/segment_anything.egg-info/dependency_links.txt +1 -0
  5. segment_anything/segment_anything.egg-info/requires.txt +13 -0
  6. segment_anything/segment_anything.egg-info/top_level.txt +1 -0
  7. segment_anything/segment_anything/__pycache__/__init__.cpython-310.pyc +0 -0
  8. segment_anything/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
  9. segment_anything/segment_anything/__pycache__/build_sam.cpython-310.pyc +0 -0
  10. segment_anything/segment_anything/__pycache__/build_sam_hq.cpython-310.pyc +0 -0
  11. segment_anything/segment_anything/__pycache__/predictor.cpython-310.pyc +0 -0
  12. segment_anything/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  13. segment_anything/segment_anything/modeling/__pycache__/common.cpython-310.pyc +0 -0
  14. segment_anything/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
  15. segment_anything/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
  16. segment_anything/segment_anything/modeling/__pycache__/mask_decoder_hq.cpython-310.pyc +0 -0
  17. segment_anything/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc +0 -0
  18. segment_anything/segment_anything/modeling/__pycache__/sam.cpython-310.pyc +0 -0
  19. segment_anything/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc +0 -0
  20. segment_anything/segment_anything/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  21. segment_anything/segment_anything/utils/__pycache__/amg.cpython-310.pyc +0 -0
  22. segment_anything/segment_anything/utils/__pycache__/transforms.cpython-310.pyc +0 -0
  23. test.py +16 -12
grounded_sam_demo.py CHANGED
@@ -1,4 +1,5 @@
- import argparse
+ from GroundingDINO.groundingdino.datasets.transforms import Compose, RandomResize, ToTensor, Normalize
+ from io import BytesIO
  import os
  import copy
@@ -16,8 +17,8 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
 
  # segment anything
  from segment_anything import (
-     sam_model_registry,
-     sam_hq_model_registry,
+     build_sam,
+     build_sam_hq,
      SamPredictor
  )
  import cv2
@@ -25,27 +26,13 @@ import numpy as np
  import matplotlib.pyplot as plt
 
 
- def load_image(image_path):
-     # load image
-     image_pil = Image.open(image_path).convert("RGB")  # load image
-
-     transform = T.Compose(
-         [
-             T.RandomResize([800], max_size=1333),
-             T.ToTensor(),
-             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-         ]
-     )
-     image, _ = transform(image_pil, None)  # 3, h, w
-     return image_pil, image
-
-
  def load_model(model_config_path, model_checkpoint_path, device):
      args = SLConfig.fromfile(model_config_path)
      args.device = device
      model = build_model(args)
      checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
-     load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
+     load_res = model.load_state_dict(
+         clean_state_dict(checkpoint["model"]), strict=False)
      print(load_res)
      _ = model.eval()
      return model
@@ -72,136 +59,38 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
      boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
      logits_filt.shape[0]
 
-     # get phrase
-     tokenlizer = model.tokenizer
-     tokenized = tokenlizer(caption)
-     # build pred
-     pred_phrases = []
-     for logit, box in zip(logits_filt, boxes_filt):
-         pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
-         if with_logits:
-             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
-         else:
-             pred_phrases.append(pred_phrase)
-
-     return boxes_filt, pred_phrases
-
-
- def show_mask(mask, ax, random_color=False):
-     if random_color:
-         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-     else:
-         color = np.array([30/255, 144/255, 255/255, 0.6])
-     h, w = mask.shape[-2:]
-     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-     ax.imshow(mask_image)
-
-
- def show_box(box, ax, label):
-     x0, y0 = box[0], box[1]
-     w, h = box[2] - box[0], box[3] - box[1]
-     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
-     ax.text(x0, y0, label)
-
-
- def save_mask_data(output_dir, mask_list, box_list, label_list):
-     value = 0  # 0 for background
-
-     mask_img = torch.zeros(mask_list.shape[-2:])
-     for idx, mask in enumerate(mask_list):
-         mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
-     plt.figure(figsize=(10, 10))
-     plt.imshow(mask_img.numpy())
-     plt.axis('off')
-     plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
-
-     json_data = [{
-         'value': value,
-         'label': 'background'
-     }]
-     for label, box in zip(label_list, box_list):
-         value += 1
-         name, logit = label.split('(')
-         logit = logit[:-1]  # the last is ')'
-         json_data.append({
-             'value': value,
-             'label': name,
-             'logit': float(logit),
-             'box': box.numpy().tolist(),
-         })
-     with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
-         json.dump(json_data, f)
-
-
- if __name__ == "__main__":
-
-     parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
-     parser.add_argument("--config", type=str, required=True, help="path to config file")
-     parser.add_argument(
-         "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
-     )
-     parser.add_argument(
-         "--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h"
-     )
-     parser.add_argument(
-         "--sam_checkpoint", type=str, required=False, help="path to sam checkpoint file"
-     )
-     parser.add_argument(
-         "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
-     )
-     parser.add_argument(
-         "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
-     )
-     parser.add_argument("--input_image", type=str, required=True, help="path to image file")
-     parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
-     parser.add_argument(
-         "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
-     )
-
-     parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
-     parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
-
-     parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
-     args = parser.parse_args()
-
-     # cfg
-     config_file = args.config  # change the path of the model config file
-     grounded_checkpoint = args.grounded_checkpoint  # change the path of the model
-     sam_version = args.sam_version
-     sam_checkpoint = args.sam_checkpoint
-     sam_hq_checkpoint = args.sam_hq_checkpoint
-     use_sam_hq = args.use_sam_hq
-     image_path = args.input_image
-     text_prompt = args.text_prompt
-     output_dir = args.output_dir
-     box_threshold = args.box_threshold
-     text_threshold = args.text_threshold
-     device = args.device
-
-     # make dir
-     os.makedirs(output_dir, exist_ok=True)
-     # load image
-     image_pil, image = load_image(image_path)
-     # load model
-     model = load_model(config_file, grounded_checkpoint, device=device)
-
-     # visualize raw image
-     image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
-
-     # run grounding dino model
-     boxes_filt, pred_phrases = get_grounding_output(
-         model, image, text_prompt, box_threshold, text_threshold, device=device
-     )
-
-     # initialize SAM
-     if use_sam_hq:
-         predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
-     else:
-         predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return boxes_filt
+
+
+ def grounded_sam_demo(input_pil, config_file, grounded_checkpoint, sam_checkpoint,
+                       text_prompt, box_threshold=0.3, text_threshold=0.25,
+                       device="cuda"):
+
+     # Convert PIL image to tensor with normalization
+     transform = Compose([
+         RandomResize([800], max_size=1333),
+         ToTensor(),
+         Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+
+     if input_pil.mode != "RGB":
+         input_pil = input_pil.convert("RGB")
+
+     image, _ = transform(input_pil, None)
+
+     # Load model
+     model = load_model(config_file, grounded_checkpoint, device=device)
+
+     # Get grounding dino model output
+     boxes_filt = get_grounding_output(
+         model, image, text_prompt, box_threshold, text_threshold, device=device)
+
+     # Initialize SAM
+     predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
+     image = cv2.cvtColor(np.array(input_pil), cv2.COLOR_RGB2BGR)
      predictor.set_image(image)
 
-     size = image_pil.size
+     size = input_pil.size
      H, W = size[1], size[0]
      for i in range(boxes_filt.size(0)):
          boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
@@ -209,27 +98,30 @@ if __name__ == "__main__":
          boxes_filt[i][2:] += boxes_filt[i][:2]
 
      boxes_filt = boxes_filt.cpu()
-     transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+     transformed_boxes = predictor.transform.apply_boxes_torch(
+         boxes_filt, image.shape[:2]).to(device)
 
      masks, _, _ = predictor.predict_torch(
-         point_coords = None,
-         point_labels = None,
-         boxes = transformed_boxes.to(device),
-         multimask_output = False,
+         point_coords=None,
+         point_labels=None,
+         boxes=transformed_boxes.to(device),
+         multimask_output=False,
      )
 
-     # draw output image
-     plt.figure(figsize=(10, 10))
-     plt.imshow(image)
-     for mask in masks:
-         show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
-     for box, label in zip(boxes_filt, pred_phrases):
-         show_box(box.numpy(), plt.gca(), label)
+     # Create mask image
+     value = 0  # 0 for background
+     mask_img = torch.zeros(masks.shape[-2:])
+     for idx, mask in enumerate(masks):
+         mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
 
+     fig = plt.figure(figsize=(10, 10))
+     plt.imshow(mask_img.numpy())
      plt.axis('off')
-     plt.savefig(
-         os.path.join(output_dir, "grounded_sam_output.jpg"),
-         bbox_inches="tight", dpi=300, pad_inches=0.0
-     )
 
-     save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
+     buf = BytesIO()
+     plt.savefig(buf, format='png', bbox_inches="tight",
+                 dpi=300, pad_inches=0.0)
+     buf.seek(0)
+     out_pil = Image.open(buf)
+
+     return out_pil
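
For orientation, a minimal sketch of how the refactored entry point might be called after this commit; the image path, config path, and checkpoint filenames below are illustrative placeholders, not values taken from the commit:

from PIL import Image
from grounded_sam_demo import grounded_sam_demo

# Hypothetical paths -- substitute real config/checkpoint locations.
input_pil = Image.open("assets/demo.jpg")
mask_pil = grounded_sam_demo(
    input_pil,
    config_file="GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    grounded_checkpoint="groundingdino_swint_ogc.pth",
    sam_checkpoint="sam_vit_h_4b8939.pth",
    text_prompt="dog",
)
mask_pil.save("mask.png")

One caveat worth noting: the new code hands SamPredictor.set_image a BGR array (cv2.COLOR_RGB2BGR), while set_image assumes RGB input by default, so mask quality may differ from the old cv2.imread + COLOR_BGR2RGB path.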
segment_anything/segment_anything.egg-info/PKG-INFO ADDED
@@ -0,0 +1,15 @@
+ Metadata-Version: 2.1
+ Name: segment-anything
+ Version: 1.0
+ License-File: LICENSE
+ Provides-Extra: all
+ Requires-Dist: matplotlib; extra == "all"
+ Requires-Dist: pycocotools; extra == "all"
+ Requires-Dist: opencv-python; extra == "all"
+ Requires-Dist: onnx; extra == "all"
+ Requires-Dist: onnxruntime; extra == "all"
+ Provides-Extra: dev
+ Requires-Dist: flake8; extra == "dev"
+ Requires-Dist: isort; extra == "dev"
+ Requires-Dist: black; extra == "dev"
+ Requires-Dist: mypy; extra == "dev"
segment_anything/segment_anything.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,26 @@
+ LICENSE
+ README.md
+ setup.cfg
+ setup.py
+ segment_anything/__init__.py
+ segment_anything/automatic_mask_generator.py
+ segment_anything/build_sam.py
+ segment_anything/build_sam_hq.py
+ segment_anything/predictor.py
+ segment_anything.egg-info/PKG-INFO
+ segment_anything.egg-info/SOURCES.txt
+ segment_anything.egg-info/dependency_links.txt
+ segment_anything.egg-info/requires.txt
+ segment_anything.egg-info/top_level.txt
+ segment_anything/modeling/__init__.py
+ segment_anything/modeling/common.py
+ segment_anything/modeling/image_encoder.py
+ segment_anything/modeling/mask_decoder.py
+ segment_anything/modeling/mask_decoder_hq.py
+ segment_anything/modeling/prompt_encoder.py
+ segment_anything/modeling/sam.py
+ segment_anything/modeling/transformer.py
+ segment_anything/utils/__init__.py
+ segment_anything/utils/amg.py
+ segment_anything/utils/onnx.py
+ segment_anything/utils/transforms.py
segment_anything/segment_anything.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
segment_anything/segment_anything.egg-info/requires.txt ADDED
@@ -0,0 +1,13 @@
+
+ [all]
+ matplotlib
+ pycocotools
+ opencv-python
+ onnx
+ onnxruntime
+
+ [dev]
+ flake8
+ isort
+ black
+ mypy
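
The [all] and [dev] groups above are standard setuptools extras. As a minimal sketch (assuming the editable package lives in this repo's segment_anything/ directory, as the test.py install step suggests), the optional dependencies can be pulled in with:

import subprocess
import sys

# "all" resolves matplotlib, pycocotools, opencv-python, onnx, and
# onnxruntime alongside the package itself; "dev" adds the lint tools.
subprocess.check_call([
    sys.executable, "-m", "pip", "install",
    "-e", "segment_anything[all]",
])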
segment_anything/segment_anything.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ segment_anything
segment_anything/segment_anything/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (572 Bytes).
segment_anything/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc ADDED
Binary file (11.4 kB).
segment_anything/segment_anything/__pycache__/build_sam.cpython-310.pyc ADDED
Binary file (2.16 kB).
segment_anything/segment_anything/__pycache__/build_sam_hq.cpython-310.pyc ADDED
Binary file (2.41 kB).
segment_anything/segment_anything/__pycache__/predictor.cpython-310.pyc ADDED
Binary file (10.1 kB).
segment_anything/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (458 Bytes).
segment_anything/segment_anything/modeling/__pycache__/common.cpython-310.pyc ADDED
Binary file (1.76 kB).
segment_anything/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc ADDED
Binary file (12.7 kB).
segment_anything/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc ADDED
Binary file (5.54 kB).
segment_anything/segment_anything/modeling/__pycache__/mask_decoder_hq.cpython-310.pyc ADDED
Binary file (6.62 kB).
segment_anything/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc ADDED
Binary file (7.69 kB).
segment_anything/segment_anything/modeling/__pycache__/sam.cpython-310.pyc ADDED
Binary file (6.67 kB).
segment_anything/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.61 kB).
segment_anything/segment_anything/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (165 Bytes).
segment_anything/segment_anything/utils/__pycache__/amg.cpython-310.pyc ADDED
Binary file (12.1 kB).
segment_anything/segment_anything/utils/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.94 kB).
test.py CHANGED
@@ -4,33 +4,37 @@ import torch
  import requests
  from PIL import Image
  from io import BytesIO
+ import subprocess
+ import sys
+
+ def pip_command(command):
+     subprocess.check_call([sys.executable, "-m", "pip"] + command.split())
 
- is_production = True
+ is_production = False
 
  os.chdir("/repository" if is_production else ".")
  os.environ['AM_I_DOCKER'] = 'False'
  os.environ['BUILD_WITH_CUDA'] = 'True'
  os.environ['CUDA_HOME'] = '/usr/local/cuda-11.7/' if is_production else '/usr/local/cuda-12.1/'
 
- # Install Segment Anything
- subprocess.run(["python", "-m", "pip", "install", "-e", "segment_anything"])
+ pip_command("install -e segment_anything")
 
- # Install Grounding DINO
- subprocess.run(["python", "-m", "pip", "install", "-e", "GroundingDINO"])
+ pip_command("install -e GroundingDINO")
 
- subprocess.run("wget https://huggingface.co/Uminosachi/sam-hq/resolve/main/sam_hq_vit_h.pth -O ./sam_hq_vit_h.pth", shell=True)
+ response = requests.get("https://huggingface.co/Uminosachi/sam-hq/resolve/main/sam_hq_vit_h.pth")
+ with open('./sam_hq_vit_h.pth', 'wb') as file:
+     file.write(response.content)
 
- # Install diffusers
- subprocess.run(["pip", "install", "--upgrade", "diffusers[torch]"])
+ pip_command("install --upgrade diffusers[torch]")
 
- # Install osx
  subprocess.run(["git", "submodule", "update", "--init", "--recursive"])
  subprocess.run(["bash", "grounded-sam-osx/install.sh"], cwd="grounded-sam-osx")
 
- # Install RAM & Tag2Text
  subprocess.run(["git", "clone", "https://github.com/xinyu1205/recognize-anything.git"])
- subprocess.run(["pip", "install", "-r", "./recognize-anything/requirements.txt"])
- subprocess.run(["pip", "install", "-e", "./recognize-anything/"])
+
+ pip_command("install -r ./recognize-anything/requirements.txt")
+
+ pip_command("install -e ./recognize-anything/")
 
  from grounded_sam_demo import grounded_sam_demo
  import numpy as np
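
A practical difference from the bare subprocess.run calls it replaces: subprocess.check_call raises CalledProcessError on a non-zero exit code, so a failed install now stops the bootstrap instead of passing silently, and sys.executable pins pip to the running interpreter. A small sketch of that failure behavior (the package path here is deliberately bogus):

import subprocess
import sys

def pip_command(command):
    # check_call raises CalledProcessError if pip exits non-zero.
    subprocess.check_call([sys.executable, "-m", "pip"] + command.split())

try:
    pip_command("install -e ./no-such-package/")
except subprocess.CalledProcessError as err:
    print(f"pip failed with exit code {err.returncode}")

Also worth noting: the replaced wget command streamed the checkpoint to disk, while response.content buffers the entire sam_hq_vit_h.pth download in memory before writing; for a file this large, requests.get(..., stream=True) with iter_content chunks would restore streaming.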