SkalskiP committed
Commit f89aac1
1 Parent(s): df57751

Update app to support confidence adjustment and mask options

Updated app.py and requirements.txt to introduce a slider for adjusting the confidence threshold and to add a solid mask annotator alongside the existing semi-transparent one. The inference function call now also takes the confidence parameter. The gradio version has been pinned to 3.50.2 in requirements.txt for consistency. These changes give users a better interaction experience and greater control over the image processing results.
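
For orientation, here is a minimal sketch of how the updated entry point is driven (based on the new inference signature shown in the diff below; the local file name is hypothetical):

    from PIL import Image

    # Hypothetical local copy of one of the EXAMPLES images.
    image = Image.open("dog.jpeg").convert("RGB")

    # inference() now takes a third argument: the CLIP confidence threshold
    # exposed by the new slider in the UI (0.5-1.0, default 0.6).
    semitransparent_overlay, solid_mask = inference(
        image_rgb_pil=image,
        prompt="dog",
        confidence=0.6)

    semitransparent_overlay.save("overlay.png")
    solid_mask.save("mask.png")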

Files changed (2)
  1. app.py +44 -19
  2. requirements.txt +1 -1
app.py CHANGED
@@ -14,11 +14,12 @@ This is the demo for a Open Vocabulary Image Segmentation using
 [MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combo.
 """
 EXAMPLES = [
-    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog"],
-    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building"],
-    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket"],
+    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
+    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
+    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
+    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
 ]
-
+MIN_AREA_THRESHOLD = 0.01
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 SAM_GENERATOR = pipeline(
     task="mask-generation",
@@ -26,9 +27,13 @@ SAM_GENERATOR = pipeline(
     device=DEVICE)
 CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
 CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
-MASK_ANNOTATOR = sv.MaskAnnotator(
+SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
     color=sv.Color.red(),
     color_lookup=sv.ColorLookup.INDEX)
+SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
+    color=sv.Color.red(),
+    color_lookup=sv.ColorLookup.INDEX,
+    opacity=1)
 
 
 def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
@@ -54,9 +59,13 @@ def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
     return np.where(mask[..., None], image, gray_color)
 
 
-def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Image:
+def annotate(
+    image_rgb_pil: Image.Image,
+    detections: sv.Detections,
+    annotator: sv.MaskAnnotator
+) -> Image.Image:
     img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
-    annotated_bgr_image = MASK_ANNOTATOR.annotate(
+    annotated_bgr_image = annotator.annotate(
         scene=img_bgr_numpy, detections=detections)
     return Image.fromarray(annotated_bgr_image[:, :, ::-1])
 
@@ -64,7 +73,8 @@ def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Ima
 def filter_detections(
     image_rgb_pil: Image.Image,
     detections: sv.Detections,
-    prompt: str
+    prompt: str,
+    confidence: float
 ) -> sv.Detections:
     img_rgb_numpy = np.array(image_rgb_pil)
     text = [f"a picture of {prompt}", "a picture of background"]
@@ -76,27 +86,38 @@ def filter_detections(
         masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
         masked_crop_pil = Image.fromarray(masked_crop)
         probs = run_clip(image_rgb_pil=masked_crop_pil, text=text)
-        lass_index = np.argmax(probs)
-        filtering_mask.append(lass_index == 0)
+        filtering_mask.append(probs[0][0] > confidence)
 
     filtering_mask = np.array(filtering_mask)
     return detections[filtering_mask]
 
 
-def inference(image_rgb_pil: Image.Image, prompt: str) -> List[Image.Image]:
+def inference(
+    image_rgb_pil: Image.Image,
+    prompt: str,
+    confidence: float
+) -> List[Image.Image]:
     width, height = image_rgb_pil.size
     area = width * height
 
     detections = run_sam(image_rgb_pil)
-    detections = detections[detections.area / area > 0.01]
+    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
     detections = filter_detections(
         image_rgb_pil=image_rgb_pil,
         detections=detections,
-        prompt=prompt)
+        prompt=prompt,
+        confidence=confidence)
 
+    blank_image = Image.new("RGB", (width, height), "black")
     return [
-        annotate(image_rgb_pil=image_rgb_pil, detections=detections),
-        annotate(image_rgb_pil=Image.new("RGB", (width, height), "black"), detections=detections)
+        annotate(
+            image_rgb_pil=image_rgb_pil,
+            detections=detections,
+            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
+        annotate(
+            image_rgb_pil=blank_image,
+            detections=detections,
+            annotator=SOLID_MASK_ANNOTATOR)
     ]
 
 
@@ -104,15 +125,19 @@ with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     with gr.Row():
        with gr.Column():
-            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
-            prompt_text = gr.Textbox(label="Prompt", value="dog")
+            input_image = gr.Image(
+                image_mode='RGB', type='pil', height=500)
+            prompt_text = gr.Textbox(
+                label="Prompt", value="dog")
+            confidence_slider = gr.Slider(
+                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
             submit_button = gr.Button("Submit")
         gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
     with gr.Row():
         gr.Examples(
             examples=EXAMPLES,
             fn=inference,
-            inputs=[input_image, prompt_text],
+            inputs=[input_image, prompt_text, confidence_slider],
             outputs=[gallery],
             cache_examples=True,
             run_on_click=True
@@ -120,7 +145,7 @@ with gr.Blocks() as demo:
 
     submit_button.click(
         inference,
-        inputs=[input_image, prompt_text],
+        inputs=[input_image, prompt_text, confidence_slider],
         outputs=gallery)
 
 demo.launch(debug=False, show_error=True)
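
The behavioral core of the change sits in filter_detections: a mask is no longer kept whenever the prompt class merely wins the argmax, but only when its CLIP probability clears the slider value. A minimal sketch of the two rules, assuming probs is the 1x2 softmax output of run_clip with the prompt class first:

    import numpy as np

    # Assumed shape of the run_clip output: [[p(prompt), p(background)]].
    probs = np.array([[0.55, 0.45]])

    # Old rule: keep the mask whenever the prompt class wins the argmax,
    # even if the call is borderline.
    keep_old = np.argmax(probs) == 0     # True

    # New rule: keep the mask only if p(prompt) clears the slider value,
    # so raising the confidence rejects borderline masks.
    confidence = 0.6
    keep_new = probs[0][0] > confidence  # False
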
requirements.txt CHANGED
@@ -4,6 +4,6 @@ torchvision
 
 numpy
 pillow
-gradio
+gradio==3.50.2
 transformers
 supervision