Jun Xiong commited on
Commit
1bd7ddc
1 Parent(s): a7df02b
Files changed (6) hide show
  1. README.md +6 -6
  2. index.css +116 -0
  3. index.html +24 -13
  4. index.js +268 -52
  5. style.css +18 -66
  6. worker.js +109 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Segment Ui
3
- emoji: 🌐
4
- colorFrom: blue
5
- colorTo: yellow
6
  sdk: static
7
  pinned: false
8
  models:
9
- - Xenova/detr-resnet-50
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Segment Anything Web
3
+ emoji: 💻
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: static
7
  pinned: false
8
  models:
9
+ - Xenova/slimsam-77-uniform
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
index.css ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ box-sizing: border-box;
3
+ padding: 0;
4
+ margin: 0;
5
+ font-family: sans-serif;
6
+ }
7
+
8
+ html,
9
+ body {
10
+ height: 100%;
11
+ }
12
+
13
+ body {
14
+ padding: 16px 32px;
15
+ }
16
+
17
+ body,
18
+ #container,
19
+ #upload-button {
20
+ display: flex;
21
+ flex-direction: column;
22
+ justify-content: center;
23
+ align-items: center;
24
+ }
25
+
26
+ h1 {
27
+ text-align: center;
28
+ }
29
+
30
+ #container {
31
+ position: relative;
32
+ width: 640px;
33
+ height: 420px;
34
+ max-width: 100%;
35
+ max-height: 100%;
36
+ border: 2px dashed #D1D5DB;
37
+ border-radius: 0.75rem;
38
+ overflow: hidden;
39
+ cursor: pointer;
40
+ margin-top: 1rem;
41
+ background-size: 100% 100%;
42
+ background-position: center;
43
+ background-repeat: no-repeat;
44
+ }
45
+
46
+ #mask-output {
47
+ position: absolute;
48
+ width: 100%;
49
+ height: 100%;
50
+ pointer-events: none;
51
+ }
52
+
53
+ #upload-button {
54
+ gap: 0.4rem;
55
+ font-size: 18px;
56
+ cursor: pointer;
57
+ }
58
+
59
+ #upload {
60
+ display: none;
61
+ }
62
+
63
+ svg {
64
+ pointer-events: none;
65
+ }
66
+
67
+ #example {
68
+ font-size: 14px;
69
+ text-decoration: underline;
70
+ cursor: pointer;
71
+ }
72
+
73
+ #example:hover {
74
+ color: #2563EB;
75
+ }
76
+
77
+ canvas {
78
+ position: absolute;
79
+ width: 100%;
80
+ height: 100%;
81
+ opacity: 0.6;
82
+ }
83
+
84
+ #status {
85
+ min-height: 16px;
86
+ margin: 8px 0;
87
+ }
88
+
89
+ .icon {
90
+ height: 16px;
91
+ width: 16px;
92
+ position: absolute;
93
+ transform: translate(-50%, -50%);
94
+ }
95
+
96
+ #controls>button {
97
+ padding: 6px 12px;
98
+ background-color: #3498db;
99
+ color: white;
100
+ border: 1px solid #2980b9;
101
+ border-radius: 5px;
102
+ cursor: pointer;
103
+ font-size: 16px;
104
+ }
105
+
106
+ #controls>button:disabled {
107
+ background-color: #d1d5db;
108
+ color: #6b7280;
109
+ border: 1px solid #9ca3af;
110
+ cursor: not-allowed;
111
+ }
112
+
113
+ #information {
114
+ margin-top: 0.25rem;
115
+ font-size: 15px;
116
+ }
index.html CHANGED
@@ -3,24 +3,35 @@
3
 
4
  <head>
5
  <meta charset="UTF-8" />
6
- <link rel="stylesheet" href="style.css" />
7
 
8
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
- <title>Transformers.js - Object Detection</title>
10
  </head>
11
 
12
  <body>
13
- <h1>Object Detection w/ 🤗 Transformers.js</h1>
14
- <label id="container" for="upload">
15
- <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
16
- <path fill="#000"
17
- d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
18
- </path>
19
- </svg>
20
- Click to upload image
21
- <label id="example">(or try example)</label>
22
- </label>
23
- <label id="status">Loading model...</label>
 
 
 
 
 
 
 
 
 
 
 
24
  <input id="upload" type="file" accept="image/*" />
25
 
26
  <script src="index.js" type="module"></script>
 
3
 
4
  <head>
5
  <meta charset="UTF-8" />
6
+ <link rel="stylesheet" href="index.css" />
7
 
8
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
+ <title>Transformers.js - Segment Anything</title>
10
  </head>
11
 
12
  <body>
13
+ <h1>Segment Anything w/ 🤗 Transformers.js</h1>
14
+ <div id="container">
15
+ <label id="upload-button" for="upload">
16
+ <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
17
+ <path fill="#000"
18
+ d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
19
+ </path>
20
+ </svg>
21
+ Click to upload image
22
+ <label id="example">(or try example)</label>
23
+ </label>
24
+ <canvas id="mask-output"></canvas>
25
+ </div>
26
+ <label id="status"></label>
27
+ <div id="controls">
28
+ <button id="reset-image">Reset image</button>
29
+ <button id="clear-points">Clear points</button>
30
+ <button id="cut-mask" disabled>Cut mask</button>
31
+ </div>
32
+ <p id="information">
33
+ Left click = positive points, right click = negative points.
34
+ </p>
35
  <input id="upload" type="file" accept="image/*" />
36
 
37
  <script src="index.js" type="module"></script>
index.js CHANGED
@@ -1,26 +1,165 @@
1
- import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/[email protected]';
2
 
3
- // Since we will download the model from the Hugging Face Hub, we can skip the local model check
4
- env.allowLocalModels = false;
5
-
6
- // Reference the elements that we will need
7
- const status = document.getElementById('status');
8
  const fileUpload = document.getElementById('upload');
9
  const imageContainer = document.getElementById('container');
10
  const example = document.getElementById('example');
 
 
 
 
 
11
 
12
- const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg';
 
 
 
 
 
 
13
 
14
- // Create a new object detection pipeline
15
- status.textContent = 'Loading model...';
16
- const detector = await pipeline('object-detection', 'Xenova/detr-resnet-50');
17
- status.textContent = 'Ready';
18
 
19
- example.addEventListener('click', (e) => {
20
- e.preventDefault();
21
- detect(EXAMPLE_URL);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  });
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  fileUpload.addEventListener('change', function (e) {
25
  const file = e.target.files[0];
26
  if (!file) {
@@ -30,50 +169,127 @@ fileUpload.addEventListener('change', function (e) {
30
  const reader = new FileReader();
31
 
32
  // Set up a callback when the file is loaded
33
- reader.onload = e2 => detect(e2.target.result);
34
 
35
  reader.readAsDataURL(file);
36
  });
37
 
 
 
 
 
38
 
39
- // Detect objects in the image
40
- async function detect(img) {
41
- imageContainer.innerHTML = '';
42
- imageContainer.style.backgroundImage = `url(${img})`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- status.textContent = 'Analysing...';
45
- const output = await detector(img, {
46
- threshold: 0.5,
47
- percentage: true,
48
- });
49
- status.textContent = '';
50
- output.forEach(renderBox);
 
 
 
 
 
 
51
  }
52
 
53
- // Render a bounding box and label on the image
54
- function renderBox({ box, label }) {
55
- const { xmax, xmin, ymax, ymin } = box;
56
-
57
- // Generate a random color for the box
58
- const color = '#' + Math.floor(Math.random() * 0xFFFFFF).toString(16).padStart(6, 0);
59
-
60
- // Draw the box
61
- const boxElement = document.createElement('div');
62
- boxElement.className = 'bounding-box';
63
- Object.assign(boxElement.style, {
64
- borderColor: color,
65
- left: 100 * xmin + '%',
66
- top: 100 * ymin + '%',
67
- width: 100 * (xmax - xmin) + '%',
68
- height: 100 * (ymax - ymin) + '%',
69
- })
70
-
71
- // Draw label
72
- const labelElement = document.createElement('span');
73
- labelElement.textContent = label;
74
- labelElement.className = 'bounding-box-label';
75
- labelElement.style.backgroundColor = color;
76
-
77
- boxElement.appendChild(labelElement);
78
- imageContainer.appendChild(boxElement);
79
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ // Reference the elements we will use
3
+ const statusLabel = document.getElementById('status');
 
 
 
4
  const fileUpload = document.getElementById('upload');
5
  const imageContainer = document.getElementById('container');
6
  const example = document.getElementById('example');
7
+ const maskCanvas = document.getElementById('mask-output');
8
+ const uploadButton = document.getElementById('upload-button');
9
+ const resetButton = document.getElementById('reset-image');
10
+ const clearButton = document.getElementById('clear-points');
11
+ const cutButton = document.getElementById('cut-mask');
12
 
13
+ // State variables
14
+ let lastPoints = null;
15
+ let isEncoded = false;
16
+ let isDecoding = false;
17
+ let isMultiMaskMode = false;
18
+ let modelReady = false;
19
+ let imageDataURI = null;
20
 
21
+ // Constants
22
+ const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
23
+ const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
 
24
 
25
+ // Create a web worker so that the main (UI) thread is not blocked during inference.
26
+ const worker = new Worker('worker.js', {
27
+ type: 'module',
28
+ });
29
+
30
+ // Preload star and cross images to avoid lag on first click
31
+ const star = new Image();
32
+ star.src = BASE_URL + 'star-icon.png';
33
+ star.className = 'icon';
34
+
35
+ const cross = new Image();
36
+ cross.src = BASE_URL + 'cross-icon.png';
37
+ cross.className = 'icon';
38
+
39
+ // Set up message handler
40
+ worker.addEventListener('message', (e) => {
41
+ const { type, data } = e.data;
42
+ if (type === 'ready') {
43
+ modelReady = true;
44
+ statusLabel.textContent = 'Ready';
45
+
46
+ } else if (type === 'decode_result') {
47
+ isDecoding = false;
48
+
49
+ if (!isEncoded) {
50
+ return; // We are not ready to decode yet
51
+ }
52
+
53
+ if (!isMultiMaskMode && lastPoints) {
54
+ // Perform decoding with the last point
55
+ decode();
56
+ lastPoints = null;
57
+ }
58
+
59
+ const { mask, scores } = data;
60
+
61
+ // Update canvas dimensions (if different)
62
+ if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
63
+ maskCanvas.width = mask.width;
64
+ maskCanvas.height = mask.height;
65
+ }
66
+
67
+ // Create context and allocate buffer for pixel data
68
+ const context = maskCanvas.getContext('2d');
69
+ const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);
70
+
71
+ // Select best mask
72
+ const numMasks = scores.length; // 3
73
+ let bestIndex = 0;
74
+ for (let i = 1; i < numMasks; ++i) {
75
+ if (scores[i] > scores[bestIndex]) {
76
+ bestIndex = i;
77
+ }
78
+ }
79
+ statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
80
+
81
+ // Fill mask with colour
82
+ const pixelData = imageData.data;
83
+ for (let i = 0; i < pixelData.length; ++i) {
84
+ if (mask.data[numMasks * i + bestIndex] === 1) {
85
+ const offset = 4 * i;
86
+ pixelData[offset] = 0; // red
87
+ pixelData[offset + 1] = 114; // green
88
+ pixelData[offset + 2] = 189; // blue
89
+ pixelData[offset + 3] = 255; // alpha
90
+ }
91
+ }
92
+
93
+ // Draw image data to context
94
+ context.putImageData(imageData, 0, 0);
95
+
96
+ } else if (type === 'segment_result') {
97
+ if (data === 'start') {
98
+ statusLabel.textContent = 'Extracting image embedding...';
99
+ } else {
100
+ statusLabel.textContent = 'Embedding extracted!';
101
+ isEncoded = true;
102
+ }
103
+ }
104
  });
105
 
106
+ function decode() {
107
+ isDecoding = true;
108
+ worker.postMessage({ type: 'decode', data: lastPoints });
109
+ }
110
+
111
+ function clearPointsAndMask() {
112
+ // Reset state
113
+ isMultiMaskMode = false;
114
+ lastPoints = null;
115
+
116
+ // Remove points from previous mask (if any)
117
+ document.querySelectorAll('.icon').forEach(e => e.remove());
118
+
119
+ // Disable cut button
120
+ cutButton.disabled = true;
121
+
122
+ // Reset mask canvas
123
+ maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
124
+ }
125
+ clearButton.addEventListener('click', clearPointsAndMask);
126
+
127
+ resetButton.addEventListener('click', () => {
128
+ // Update state
129
+ isEncoded = false;
130
+ imageDataURI = null;
131
+
132
+ // Indicate to worker that we have reset the state
133
+ worker.postMessage({ type: 'reset' });
134
+
135
+ // Clear points and mask (if present)
136
+ clearPointsAndMask();
137
+
138
+ // Update UI
139
+ cutButton.disabled = true;
140
+ imageContainer.style.backgroundImage = 'none';
141
+ uploadButton.style.display = 'flex';
142
+ statusLabel.textContent = 'Ready';
143
+ });
144
+
145
+ function segment(data) {
146
+ // Update state
147
+ isEncoded = false;
148
+ if (!modelReady) {
149
+ statusLabel.textContent = 'Loading model...';
150
+ }
151
+ imageDataURI = data;
152
+
153
+ // Update UI
154
+ imageContainer.style.backgroundImage = `url(${data})`;
155
+ uploadButton.style.display = 'none';
156
+ cutButton.disabled = true;
157
+
158
+ // Instruct worker to segment the image
159
+ worker.postMessage({ type: 'segment', data });
160
+ }
161
+
162
+ // Handle file selection
163
  fileUpload.addEventListener('change', function (e) {
164
  const file = e.target.files[0];
165
  if (!file) {
 
169
  const reader = new FileReader();
170
 
171
  // Set up a callback when the file is loaded
172
+ reader.onload = e2 => segment(e2.target.result);
173
 
174
  reader.readAsDataURL(file);
175
  });
176
 
177
+ example.addEventListener('click', (e) => {
178
+ e.preventDefault();
179
+ segment(EXAMPLE_URL);
180
+ });
181
 
182
+ function addIcon({ point, label }) {
183
+ const icon = (label === 1 ? star : cross).cloneNode();
184
+ icon.style.left = `${point[0] * 100}%`;
185
+ icon.style.top = `${point[1] * 100}%`;
186
+ imageContainer.appendChild(icon);
187
+ }
188
+
189
+ // Attach hover event to image container
190
+ imageContainer.addEventListener('mousedown', e => {
191
+ if (e.button !== 0 && e.button !== 2) {
192
+ return; // Ignore other buttons
193
+ }
194
+ if (!isEncoded) {
195
+ return; // Ignore if not encoded yet
196
+ }
197
+ if (!isMultiMaskMode) {
198
+ lastPoints = [];
199
+ isMultiMaskMode = true;
200
+ cutButton.disabled = false;
201
+ }
202
 
203
+ const point = getPoint(e);
204
+ lastPoints.push(point);
205
+
206
+ // add icon
207
+ addIcon(point);
208
+
209
+ decode();
210
+ });
211
+
212
+
213
+ // Clamp a value inside a range [min, max]
214
+ function clamp(x, min = 0, max = 1) {
215
+ return Math.max(Math.min(x, max), min)
216
  }
217
 
218
+ function getPoint(e) {
219
+ // Get bounding box
220
+ const bb = imageContainer.getBoundingClientRect();
221
+
222
+ // Get the mouse coordinates relative to the container
223
+ const mouseX = clamp((e.clientX - bb.left) / bb.width);
224
+ const mouseY = clamp((e.clientY - bb.top) / bb.height);
225
+
226
+ return {
227
+ point: [mouseX, mouseY],
228
+ label: e.button === 2 // right click
229
+ ? 0 // negative prompt
230
+ : 1, // positive prompt
231
+ }
 
 
 
 
 
 
 
 
 
 
 
 
232
  }
233
+
234
+ // Do not show context menu on right click
235
+ imageContainer.addEventListener('contextmenu', e => {
236
+ e.preventDefault();
237
+ });
238
+
239
+ // Attach hover event to image container
240
+ imageContainer.addEventListener('mousemove', e => {
241
+ if (!isEncoded || isMultiMaskMode) {
242
+ // Ignore mousemove events if the image is not encoded yet,
243
+ // or we are in multi-mask mode
244
+ return;
245
+ }
246
+ lastPoints = [getPoint(e)];
247
+
248
+ if (!isDecoding) {
249
+ decode(); // Only decode if we are not already decoding
250
+ }
251
+ });
252
+
253
+ // Handle cut button click
254
+ cutButton.addEventListener('click', () => {
255
+ const [w, h] = [maskCanvas.width, maskCanvas.height];
256
+
257
+ // Get the mask pixel data
258
+ const maskContext = maskCanvas.getContext('2d');
259
+ const maskPixelData = maskContext.getImageData(0, 0, w, h);
260
+
261
+ // Load the image
262
+ const image = new Image();
263
+ image.crossOrigin = 'anonymous';
264
+ image.onload = async () => {
265
+ // Create a new canvas to hold the image
266
+ const imageCanvas = new OffscreenCanvas(w, h);
267
+ const imageContext = imageCanvas.getContext('2d');
268
+ imageContext.drawImage(image, 0, 0, w, h);
269
+ const imagePixelData = imageContext.getImageData(0, 0, w, h);
270
+
271
+ // Create a new canvas to hold the cut-out
272
+ const cutCanvas = new OffscreenCanvas(w, h);
273
+ const cutContext = cutCanvas.getContext('2d');
274
+ const cutPixelData = cutContext.getImageData(0, 0, w, h);
275
+
276
+ // Copy the image pixel data to the cut canvas
277
+ for (let i = 3; i < maskPixelData.data.length; i += 4) {
278
+ if (maskPixelData.data[i] > 0) {
279
+ for (let j = 0; j < 4; ++j) {
280
+ const offset = i - j;
281
+ cutPixelData.data[offset] = imagePixelData.data[offset];
282
+ }
283
+ }
284
+ }
285
+ cutContext.putImageData(cutPixelData, 0, 0);
286
+
287
+ // Download image
288
+ const link = document.createElement('a');
289
+ link.download = 'image.png';
290
+ link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
291
+ link.click();
292
+ link.remove();
293
+ }
294
+ image.src = imageDataURI;
295
+ });
style.css CHANGED
@@ -1,76 +1,28 @@
1
- * {
2
- box-sizing: border-box;
3
- padding: 0;
4
- margin: 0;
5
- font-family: sans-serif;
6
- }
7
-
8
- html,
9
- body {
10
- height: 100%;
11
- }
12
-
13
  body {
14
- padding: 32px;
 
15
  }
16
 
17
- body,
18
- #container {
19
- display: flex;
20
- flex-direction: column;
21
- justify-content: center;
22
- align-items: center;
23
  }
24
 
25
- #container {
26
- position: relative;
27
- gap: 0.4rem;
28
-
29
- width: 640px;
30
- height: 640px;
31
- max-width: 100%;
32
- max-height: 100%;
33
-
34
- border: 2px dashed #D1D5DB;
35
- border-radius: 0.75rem;
36
- overflow: hidden;
37
- cursor: pointer;
38
- margin: 1rem;
39
-
40
- background-size: 100% 100%;
41
- background-position: center;
42
- background-repeat: no-repeat;
43
- font-size: 18px;
44
  }
45
 
46
- #upload {
47
- display: none;
 
 
 
 
48
  }
49
 
50
- svg {
51
- pointer-events: none;
52
  }
53
-
54
- #example {
55
- font-size: 14px;
56
- text-decoration: underline;
57
- cursor: pointer;
58
- }
59
-
60
- #example:hover {
61
- color: #2563EB;
62
- }
63
-
64
- .bounding-box {
65
- position: absolute;
66
- box-sizing: border-box;
67
- border: solid 2px;
68
- }
69
-
70
- .bounding-box-label {
71
- color: white;
72
- position: absolute;
73
- font-size: 12px;
74
- margin: -16px 0 0 -2px;
75
- padding: 1px;
76
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  body {
2
+ padding: 2rem;
3
+ font-family: -apple-system, BlinkMacSystemFont, "Arial", sans-serif;
4
  }
5
 
6
+ h1 {
7
+ font-size: 16px;
8
+ margin-top: 0;
 
 
 
9
  }
10
 
11
+ p {
12
+ color: rgb(107, 114, 128);
13
+ font-size: 15px;
14
+ margin-bottom: 10px;
15
+ margin-top: 5px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  }
17
 
18
+ .card {
19
+ max-width: 620px;
20
+ margin: 0 auto;
21
+ padding: 16px;
22
+ border: 1px solid lightgray;
23
+ border-radius: 16px;
24
  }
25
 
26
+ .card p:last-child {
27
+ margin-bottom: 0;
28
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
worker.js ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { env, SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@xenova/[email protected]';
2
+
3
+ // Since we will download the model from the Hugging Face Hub, we can skip the local model check
4
+ env.allowLocalModels = false;
5
+
6
+ // We adopt the singleton pattern to enable lazy-loading of the model and processor.
7
+ export class SegmentAnythingSingleton {
8
+ static model_id = 'Xenova/slimsam-77-uniform';
9
+ static model;
10
+ static processor;
11
+ static quantized = true;
12
+
13
+ static getInstance() {
14
+ if (!this.model) {
15
+ this.model = SamModel.from_pretrained(this.model_id, {
16
+ quantized: this.quantized,
17
+ });
18
+ }
19
+ if (!this.processor) {
20
+ this.processor = AutoProcessor.from_pretrained(this.model_id);
21
+ }
22
+
23
+ return Promise.all([this.model, this.processor]);
24
+ }
25
+ }
26
+
27
+
28
+ // State variables
29
+ let image_embeddings = null;
30
+ let image_inputs = null;
31
+ let ready = false;
32
+
33
+ self.onmessage = async (e) => {
34
+ const [model, processor] = await SegmentAnythingSingleton.getInstance();
35
+ if (!ready) {
36
+ // Indicate that we are ready to accept requests
37
+ ready = true;
38
+ self.postMessage({
39
+ type: 'ready',
40
+ });
41
+ }
42
+
43
+ const { type, data } = e.data;
44
+ if (type === 'reset') {
45
+ image_inputs = null;
46
+ image_embeddings = null;
47
+
48
+ } else if (type === 'segment') {
49
+ // Indicate that we are starting to segment the image
50
+ self.postMessage({
51
+ type: 'segment_result',
52
+ data: 'start',
53
+ });
54
+
55
+ // Read the image and recompute image embeddings
56
+ const image = await RawImage.read(e.data.data);
57
+ image_inputs = await processor(image);
58
+ image_embeddings = await model.get_image_embeddings(image_inputs)
59
+
60
+ // Indicate that we have computed the image embeddings, and we are ready to accept decoding requests
61
+ self.postMessage({
62
+ type: 'segment_result',
63
+ data: 'done',
64
+ });
65
+
66
+ } else if (type === 'decode') {
67
+ // Prepare inputs for decoding
68
+ const reshaped = image_inputs.reshaped_input_sizes[0];
69
+ const points = data.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
70
+ const labels = data.map(x => BigInt(x.label));
71
+
72
+ const input_points = new Tensor(
73
+ 'float32',
74
+ points.flat(Infinity),
75
+ [1, 1, points.length, 2],
76
+ )
77
+ const input_labels = new Tensor(
78
+ 'int64',
79
+ labels.flat(Infinity),
80
+ [1, 1, labels.length],
81
+ )
82
+
83
+ // Generate the mask
84
+ const outputs = await model({
85
+ ...image_embeddings,
86
+ input_points,
87
+ input_labels,
88
+ })
89
+
90
+ // Post-process the mask
91
+ const masks = await processor.post_process_masks(
92
+ outputs.pred_masks,
93
+ image_inputs.original_sizes,
94
+ image_inputs.reshaped_input_sizes,
95
+ );
96
+
97
+ // Send the result back to the main thread
98
+ self.postMessage({
99
+ type: 'decode_result',
100
+ data: {
101
+ mask: RawImage.fromTensor(masks[0][0]),
102
+ scores: outputs.iou_scores.data,
103
+ },
104
+ });
105
+
106
+ } else {
107
+ throw new Error(`Unknown message type: ${type}`);
108
+ }
109
+ }