kushagra124 committed
Commit c2e6eeb
1 Parent(s): ff9f53e

adding app with CLIP image segmentation

Files changed (4)
  1. app.py +26 -38
  2. images/rooom2.jpg +0 -0
  3. images/seats.jpg +0 -0
  4. images/vegetables.jpg +0 -0
app.py CHANGED
@@ -13,25 +13,16 @@ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 classes = list()
 
-def create_mask(image,image_mask,alpha=0.7):
-    mask = np.zeros_like(image)
-    # copy your image_mask to all dimensions (i.e. colors) of your image
-    for i in range(3):
-        mask[:,:,i] = image_mask.copy()
-    # apply the mask to your image
-    overlay_image = cv2.addWeighted(mask,alpha,image,1-alpha,0)
-    return overlay_image
+def create_rgb_mask(mask):
+    color = tuple(np.random.choice(range(0,256), size=3))
+    gray_3_channel = cv2.merge((mask, mask, mask))
+    gray_3_channel[mask==255] = color
+    return gray_3_channel.astype(np.uint8)
 
-def rescale_bbox(bbox,orig_image_shape=(1024,1024),model_shape=352):
-    bbox = np.asarray(bbox)/model_shape
-    y1,y2 = bbox[::2] *orig_image_shape[0]
-    x1,x2 = bbox[1::2]*orig_image_shape[1]
-    return [int(y1),int(x1),int(y2),int(x2)]
 
 def detect_using_clip(image,prompts=[],threshould=0.4):
     h,w = image.shape[:2]
-    model_detections = dict()
-    predicted_images = dict()
+    predicted_masks = list()
     inputs = processor(
         text=prompts,
         images=[image] * len(prompts),
@@ -42,31 +33,25 @@ def detect_using_clip(image,prompts=[],threshould=0.4):
     outputs = model(**inputs)
     preds = outputs.logits.unsqueeze(1)
 
-    detection = outputs.logits[0] # Assuming class index 0
     for i,prompt in enumerate(prompts):
         predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
-        predicted_image = np.where(predicted_image>threshould,np.random.randint(128,255),0)
-        # extract countours from the image
-        lbl_0 = label(predicted_image)
-        props = regionprops(lbl_0)
-        prompt = prompt.lower()
-
-        model_detections[prompt] = [rescale_bbox(prop.bbox,orig_image_shape=image.shape[:2],model_shape=predicted_image.shape[0]) for prop in props]
-        predicted_images[prompt]= predicted_image
-    return model_detections , predicted_images
+        predicted_image = np.where(predicted_image>threshould,255,0).astype(np.uint8)
+
+        predicted_masks.append(create_rgb_mask(predicted_image))
+    return predicted_masks
 
-def visualize_images(image,detections,predicted_images,prompt):
+def visualize_images(image,predicted_images):
     alpha = 0.7
     # H,W = image.shape[:2]
-    prompt = prompt.lower()
     image_resize = cv2.resize(image,(352,352))
-    mask_image = create_mask(image=image_resize,image_mask=predicted_images[prompt])
-
-    if prompt not in detections.keys():
-        print("prompt not in query ..")
-        return image_resize
-    final_image = cv2.addWeighted(image_resize,alpha,mask_image,1-alpha,0)
-    return final_image
+    resize_image_copy = image_resize.copy()
+
+    for mask_image in predicted_images:
+        resize_image_copy = cv2.addWeighted(resize_image_copy,alpha,mask_image,1-alpha,10)
+
+    return cv2.convertScaleAbs(resize_image_copy, alpha=1.8, beta=15)
 
-def shot(image, labels_text,selected_categoty):
+def shot(image, labels_text):
     if "," in labels_text:
@@ -74,20 +59,23 @@ def shot(image, labels_text,selected_categoty):
     else:
         prompts = [labels_text]
     prompts = list(map(lambda x: x.strip(),prompts))
-    model_detections,predicted_images = detect_using_clip(image,prompts=prompts)
-
-    category_image = visualize_images(image=image,detections=model_detections,predicted_images=predicted_images,prompt=selected_categoty)
 
+    predicted_images = detect_using_clip(image,prompts=prompts)
+
+    category_image = visualize_images(image=image,predicted_images=predicted_images)
     return category_image
 
 iface = gr.Interface(fn=shot,
-                     inputs = ["image","text","text"],
+                     inputs = ["image","text"],
                      outputs = "image",
                      description = "Add an image and a list of categories to be detected, separated by commas (at least 2)",
                      title = "Zero-shot Image Segmentation with Prompt",
                      examples=[
-                         ["images/room.jpg","bed, table, plant, light, window",'plant'],
-                         ["images/image2.png","banner, building,door, sign","sign"]
+                         ["images/room.jpg","bed, table, plant, light, window"],
+                         ["images/image2.png","banner, building, door, sign"],
+                         ["images/seats.jpg","door, table, chairs"],
+                         ["images/vegetables.jpg","carrot, radish, beans, potato, brinjal, basket"],
+                         ["images/rooom2.jpg","door, plants, dog, coffee table, mug, pillow, table lamp, carpet, pictures, clock"]
                      ],
                      # allow_flagging=False,
                      # analytics_enabled=False,
images/rooom2.jpg ADDED
images/seats.jpg ADDED
images/vegetables.jpg ADDED
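
For reference, the new pipeline can be exercised outside Gradio. The sketch below condenses the committed create_rgb_mask / detect_using_clip / visualize_images flow into one standalone script. The processor keyword arguments sit in lines the diff does not show, so padding="max_length" and return_tensors="pt" here are assumptions based on standard CLIPSeg usage, and the image path and prompt list are placeholders borrowed from the examples.

import cv2
import numpy as np
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# placeholder inputs: any RGB image and any list of category prompts work
image = cv2.cvtColor(cv2.imread("images/seats.jpg"), cv2.COLOR_BGR2RGB)
prompts = ["door", "table", "chairs"]

# assumed processor kwargs (these lines are not visible in the diff above)
inputs = processor(text=prompts, images=[image] * len(prompts),
                   padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)  # one 352x352 logit map per prompt

overlay = cv2.resize(image, (352, 352))
for i in range(len(prompts)):
    prob = torch.sigmoid(preds[i][0]).cpu().numpy()
    mask = np.where(prob > 0.4, 255, 0).astype(np.uint8)  # binarise at the app's threshold
    rgb_mask = cv2.merge((mask, mask, mask))
    rgb_mask[mask == 255] = np.random.randint(0, 256, size=3, dtype=np.uint8)  # random colour per prompt
    overlay = cv2.addWeighted(overlay, 0.7, rgb_mask, 0.3, 10)  # same blend as visualize_images

cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

Thresholding the sigmoid logits at 0.4 and colouring each binary mask independently is what lets several categories be blended into a single overlay, replacing the per-category bounding-box path (rescale_bbox, label, regionprops) that this commit deletes.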