Spaces:
Running
on
T4
Running
on
T4
liuyizhang
commited on
Commit
•
c419c35
1
Parent(s):
e5f7fa3
update app.py
Browse files
app.py
CHANGED
@@ -116,18 +116,16 @@ def load_image(image_path):
|
|
116 |
image, _ = transform(image_pil, None) # 3, h, w
|
117 |
return image_pil, image
|
118 |
|
119 |
-
|
120 |
def load_model(model_config_path, model_checkpoint_path, device):
|
121 |
args = SLConfig.fromfile(model_config_path)
|
122 |
args.device = device
|
123 |
model = build_model(args)
|
124 |
-
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
|
125 |
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
126 |
print(load_res)
|
127 |
_ = model.eval()
|
128 |
return model
|
129 |
|
130 |
-
|
131 |
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
132 |
caption = caption.lower()
|
133 |
caption = caption.strip()
|
@@ -172,14 +170,12 @@ def show_mask(mask, ax, random_color=False):
|
|
172 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
173 |
ax.imshow(mask_image)
|
174 |
|
175 |
-
|
176 |
def show_box(box, ax, label):
|
177 |
x0, y0 = box[0], box[1]
|
178 |
w, h = box[2] - box[0], box[3] - box[1]
|
179 |
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
|
180 |
ax.text(x0, y0, label)
|
181 |
|
182 |
-
|
183 |
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
|
184 |
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
185 |
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
@@ -189,6 +185,19 @@ device = "cuda"
|
|
189 |
|
190 |
device = get_device()
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
|
193 |
assert text_prompt, 'text_prompt is not found!'
|
194 |
|
@@ -196,24 +205,20 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
|
|
196 |
os.makedirs(output_dir, exist_ok=True)
|
197 |
# load image
|
198 |
image_pil, image = load_image(image_path.convert("RGB"))
|
199 |
-
# load model
|
200 |
-
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
201 |
|
202 |
# visualize raw image
|
203 |
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
|
204 |
|
205 |
# run grounding dino model
|
206 |
boxes_filt, pred_phrases = get_grounding_output(
|
207 |
-
|
208 |
)
|
209 |
|
210 |
size = image_pil.size
|
211 |
|
212 |
if task_type == 'segment' or task_type == 'inpainting':
|
213 |
-
# initialize SAM
|
214 |
-
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
215 |
image = np.array(image_path)
|
216 |
-
|
217 |
|
218 |
H, W = size[1], size[0]
|
219 |
for i in range(boxes_filt.size(0)):
|
@@ -222,9 +227,9 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
|
|
222 |
boxes_filt[i][2:] += boxes_filt[i][:2]
|
223 |
|
224 |
boxes_filt = boxes_filt.cpu()
|
225 |
-
transformed_boxes =
|
226 |
|
227 |
-
masks, _, _ =
|
228 |
point_coords = None,
|
229 |
point_labels = None,
|
230 |
boxes = transformed_boxes,
|
@@ -266,14 +271,8 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
|
|
266 |
mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
|
267 |
mask_pil = Image.fromarray(mask)
|
268 |
image_pil = Image.fromarray(image)
|
269 |
-
|
270 |
-
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
271 |
-
"runwayml/stable-diffusion-inpainting",
|
272 |
-
# torch_dtype=torch.float16
|
273 |
-
)
|
274 |
-
pipe = pipe.to(device)
|
275 |
|
276 |
-
image =
|
277 |
image_path = os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")
|
278 |
image.save(image_path)
|
279 |
image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
|
|
116 |
image, _ = transform(image_pil, None) # 3, h, w
|
117 |
return image_pil, image
|
118 |
|
|
|
119 |
def load_model(model_config_path, model_checkpoint_path, device):
|
120 |
args = SLConfig.fromfile(model_config_path)
|
121 |
args.device = device
|
122 |
model = build_model(args)
|
123 |
+
checkpoint = torch.load(model_checkpoint_path, map_location=device) #"cpu")
|
124 |
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
125 |
print(load_res)
|
126 |
_ = model.eval()
|
127 |
return model
|
128 |
|
|
|
129 |
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
130 |
caption = caption.lower()
|
131 |
caption = caption.strip()
|
|
|
170 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
171 |
ax.imshow(mask_image)
|
172 |
|
|
|
173 |
def show_box(box, ax, label):
|
174 |
x0, y0 = box[0], box[1]
|
175 |
w, h = box[2] - box[0], box[3] - box[1]
|
176 |
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
|
177 |
ax.text(x0, y0, label)
|
178 |
|
|
|
179 |
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
|
180 |
ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
181 |
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
|
|
185 |
|
186 |
device = get_device()
|
187 |
|
188 |
+
# initialize groundingdino model
|
189 |
+
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
190 |
+
|
191 |
+
# initialize SAM
|
192 |
+
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
193 |
+
|
194 |
+
# initialize stable-diffusion-inpainting
|
195 |
+
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
196 |
+
"runwayml/stable-diffusion-inpainting",
|
197 |
+
# torch_dtype=torch.float16
|
198 |
+
)
|
199 |
+
sd_pipe = sd_pipe.to(device)
|
200 |
+
|
201 |
def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
|
202 |
assert text_prompt, 'text_prompt is not found!'
|
203 |
|
|
|
205 |
os.makedirs(output_dir, exist_ok=True)
|
206 |
# load image
|
207 |
image_pil, image = load_image(image_path.convert("RGB"))
|
|
|
|
|
208 |
|
209 |
# visualize raw image
|
210 |
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
|
211 |
|
212 |
# run grounding dino model
|
213 |
boxes_filt, pred_phrases = get_grounding_output(
|
214 |
+
groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=device
|
215 |
)
|
216 |
|
217 |
size = image_pil.size
|
218 |
|
219 |
if task_type == 'segment' or task_type == 'inpainting':
|
|
|
|
|
220 |
image = np.array(image_path)
|
221 |
+
sam_predictor.set_image(image)
|
222 |
|
223 |
H, W = size[1], size[0]
|
224 |
for i in range(boxes_filt.size(0)):
|
|
|
227 |
boxes_filt[i][2:] += boxes_filt[i][:2]
|
228 |
|
229 |
boxes_filt = boxes_filt.cpu()
|
230 |
+
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
|
231 |
|
232 |
+
masks, _, _ = sam_predictor.predict_torch(
|
233 |
point_coords = None,
|
234 |
point_labels = None,
|
235 |
boxes = transformed_boxes,
|
|
|
271 |
mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release
|
272 |
mask_pil = Image.fromarray(mask)
|
273 |
image_pil = Image.fromarray(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
+
image = sd_pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
|
276 |
image_path = os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")
|
277 |
image.save(image_path)
|
278 |
image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|