mahmudunnabi commited on
Commit
1db4907
1 Parent(s): 7f1929a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import the dependencies
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import SamModel, SamProcessor
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+ # Load the SAM model and processor
11
+ model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
12
+ processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
13
+
14
+
15
+ # Global variable to store input points
16
+ input_points = []
17
+
18
+ # Helper functions
19
+ def show_mask(mask, ax, random_color=False):
20
+ if random_color:
21
+ color = np.concatenate([np.random.random(3),
22
+ np.array([0.6])],
23
+ axis=0)
24
+ else:
25
+ color = np.array([30/255, 144/255, 255/255, 0.6])
26
+ h, w = mask.shape[-2:]
27
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
28
+ ax.imshow(mask_image)
29
+ # Function to get pixel coordinates
30
+ def get_pixel_coordinates(image, evt: gr.SelectData):
31
+ global input_points
32
+ x, y = evt.index[0], evt.index[1]
33
+ input_points = [[[x, y]]]
34
+ return perform_prediction(image)
35
+
36
+ # Function to perform SAM model prediction
37
+ def perform_prediction(image):
38
+ global input_points
39
+ # Preprocess the image
40
+ inputs = processor(images=image, input_points=input_points, return_tensors="pt")
41
+ # Perform prediction
42
+ with torch.no_grad():
43
+ outputs = model(**inputs)
44
+ iou = outputs.iou_scores
45
+ max_iou_index = torch.argmax(iou)
46
+
47
+ # Post-process the masks
48
+ predicted_masks = processor.image_processor.post_process_masks(
49
+ outputs.pred_masks,
50
+ inputs['original_sizes'],
51
+ inputs['reshaped_input_sizes']
52
+ )
53
+ predicted_mask = predicted_masks[0]
54
+
55
+ # Display the mask on the image
56
+ mask_image = show_mask_on_image(image, predicted_mask[:,max_iou_index], return_image=True)
57
+ return mask_image
58
+
59
+ # Function to overlay mask on the image
60
+ def show_mask_on_image(raw_image, mask, return_image=False):
61
+ if not isinstance(mask, torch.Tensor):
62
+ mask = torch.Tensor(mask)
63
+
64
+ if len(mask.shape) == 4:
65
+ mask = mask.squeeze()
66
+
67
+ fig, axes = plt.subplots(1, 1, figsize=(15, 15))
68
+
69
+ mask = mask.cpu().detach()
70
+ axes.imshow(np.array(raw_image))
71
+ show_mask(mask, axes)
72
+ axes.axis("off")
73
+ plt.show()
74
+
75
+ if return_image:
76
+ fig = plt.gcf()
77
+ fig.canvas.draw()
78
+ # Convert plot to image
79
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
80
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
81
+ img = Image.fromarray(img)
82
+ plt.close(fig)
83
+ return img
84
+
85
+
86
+
87
+ # Create the Gradio interface
88
+ with gr.Blocks() as demo:
89
+ gr.Markdown(
90
+ """
91
+ <div style='text-align: center; font-family: "Times New Roman";'>
92
+ <h1 style='color: #FF6347;'>One Click Image Segmentation App</h1>
93
+ <h3 style='color: #4682B4;'>Model: SlimSAM-uniform-77</h3>
94
+ <h3 style='color: #32CD32;'>Made By: Md. Mahmudun Nabi</h3>
95
+ </div>
96
+ """
97
+ )
98
+ with gr.Row():
99
+
100
+ img = gr.Image(type="pil", label="Input Image",height=400, width=600)
101
+ output_image = gr.Image(label="Masked Image")
102
+
103
+ img.select(get_pixel_coordinates, inputs=[img], outputs=[output_image])
104
+
105
+
106
+ if __name__ == "__main__":
107
+ demo.launch(share=False)