tjw committed
Commit
ba529ff
1 Parent(s): 5856612
Files changed (5)
  1. aimodel.py +404 -0
  2. environment.yml +435 -0
  3. main.py +57 -0
  4. readme.txt +5 -0
  5. test_rect.py +141 -0
aimodel.py ADDED
@@ -0,0 +1,404 @@
1
+ # %%
2
+ import matplotlib.style
3
+ from transformers import AutoProcessor, AutoModelForCausalLM
4
+ from PIL import Image
5
+ import pickle
6
+ import torch
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ from PIL import ImageDraw
10
+ from IPython.display import display
11
+ import numpy as np
12
+ from collections import namedtuple
13
+ from logging import getLogger
14
+ logger = getLogger(__name__)
15
+ # %%
16
+ class Florence:
17
+ def __init__(self, model_id:str, hack=False):
18
+ if hack:
19
+ return
20
+ self.model = (
21
+ AutoModelForCausalLM.from_pretrained(
22
+ model_id, trust_remote_code=True, torch_dtype="auto"
23
+ )
24
+ .eval()
25
+ .cuda()
26
+ )
27
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
28
+ self.model_id = model_id
29
+ def run(self, img:Image, task_prompt:str, extra_text:str|None=None):
30
+ logger.debug(f"run {task_prompt} {extra_text}")
31
+ model, processor = self.model, self.processor
32
+ prompt = task_prompt + (extra_text if extra_text else "")
33
+ inputs = processor(text=prompt, images=img, return_tensors="pt").to(
34
+ "cuda", torch.float16
35
+ )
36
+ generated_ids = model.generate(
37
+ input_ids=inputs["input_ids"],
38
+ pixel_values=inputs["pixel_values"],
39
+ max_new_tokens=1024,
40
+ early_stopping=False,
41
+ do_sample=False,
42
+ num_beams=3,
43
+ #temperature=0.1,
44
+ )
45
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
46
+ parsed_answer = processor.post_process_generation(
47
+ generated_text,
48
+ task=task_prompt,
49
+ image_size=(img.width, img.height),
50
+ )
51
+ return parsed_answer
52
+ def model_init(hack=False):
53
+ fl = Florence("microsoft/Florence-2-large", hack=hack)
54
+ fl_ft = Florence("microsoft/Florence-2-large-ft", hack=hack)
55
+ return fl, fl_ft
56
+ #%%
57
+ # florence-2 tasks
58
+ TASK_OD = "<OD>"
59
+ TASK_SEGMENTATION = '<REFERRING_EXPRESSION_SEGMENTATION>'
60
+ TASK_CAPTION = "<CAPTION_TO_PHRASE_GROUNDING>"
61
+ TASK_OCR = "<OCR_WITH_REGION>"
62
+ TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
63
+ #%%
64
+ AIModelResult = namedtuple('AIModelResult',
65
+ ['img', 'img2', 'meter_bbox', 'needle_polygons', 'circle_polygons', 'ocr1', 'ocr2'])
66
+ cached_results:dict[str, AIModelResult] = {}
67
+
68
+ #%%
69
+ def get_meter_bbox(fl:Florence, img:Image):
70
+ task_prompt, extra_text = TASK_GROUNDING, "a circular meter with white background"
71
+ parsed_answer = fl.run(img, task_prompt, extra_text)
72
+ assert len(parsed_answer) == 1
73
+ k,v = parsed_answer.popitem()
74
+ assert 'bboxes' in v
75
+ assert 'labels' in v
76
+ assert len(v['bboxes']) == 1
77
+ assert len(v['labels']) == 1
78
+ assert v['labels'][0] == 'a circular meter'
79
+ bbox = v['bboxes'][0]
80
+ return bbox
81
+
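+ # get_circles (below): everything outside the given polygons (in practice the needle mask)
+ # is painted white, and Florence-2 is then asked to segment "a circle" in the masked image.
+ # The returned polygons presumably outline the circular hub around the needle pivot; their
+ # centroid is used later as the meter centre.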
82
+ def get_circles(fl:Florence, img2:Image, polygons:list):
83
+ img3 = Image.new('L', img2.size, color = 'black')
84
+ draw = ImageDraw.Draw(img3)
85
+ for polygon in polygons:
86
+ draw.polygon(polygon, outline='white', width=3, fill='white')
87
+ img2a = np.where(np.array(img3)[:,:,None]>0, np.array(img2), 255)
88
+ img4 = Image.fromarray(img2a)
89
+ parsed_answer = fl.run(img4, TASK_SEGMENTATION, "a circle")
90
+ assert len(parsed_answer) == 1
91
+ k,v = parsed_answer.popitem()
92
+ assert 'polygons' in v
93
+ assert len(v['polygons']) == 1
94
+ return v['polygons'][0]
95
+
96
+ def get_needle_polygons(fl:Florence, img2:Image):
97
+ parsed_answer = fl.run(img2, TASK_SEGMENTATION, "the long narrow black needle hand pass through the center of the circular meter")
98
+ assert len(parsed_answer) == 1
99
+ k,v = parsed_answer.popitem()
100
+ assert 'polygons' in v
101
+ assert len(v['polygons']) == 1
102
+ needle_polygons = v['polygons'][0]
103
+ return needle_polygons
104
+
105
+ def get_ocr(fl:Florence, img2:Image):
106
+ parsed_answer = fl.run(img2, TASK_OCR)
107
+ assert len(parsed_answer)==1
108
+ k,v = parsed_answer.popitem()
109
+ return v
110
+
111
+ def get_ai_model_result(img:Image.Image|Path|str, fl:Florence, fl_ft:Florence):
112
+ if isinstance(img, Path):
113
+ key = img.parts[-1]
114
+ elif isinstance(img, str):
115
+ key = img.split('/')[-1]
116
+ else:
117
+ key = None
118
+ if key is not None and key in cached_results:
119
+ return cached_results[key]
120
+ if isinstance(img, (Path, str)):
121
+ img = Image.open(img)
122
+ meter_bbox = get_meter_bbox(fl, img)
123
+ img2 = img.crop(meter_bbox)
124
+ needle_polygons = get_needle_polygons(fl, img2)
125
+ result = AIModelResult(img, img2, meter_bbox, needle_polygons,
126
+ get_circles(fl, img2, needle_polygons),
127
+ get_ocr(fl, img2),
128
+ get_ocr(fl_ft, img2)
129
+ )
130
+ if key is not None:
131
+ cached_results[key] = result
132
+ return result
133
+ #%%
134
+ from skimage.measure import regionprops
135
+ from skimage.measure import EllipseModel
136
+ from skimage.draw import ellipse_perimeter
137
+ def get_regionprops(polygons:list) -> regionprops:
138
+ coords = np.concatenate(polygons).reshape(-1, 2)
139
+ size = tuple( (coords.max(axis=0)+2).astype('int') )
140
+ img = Image.new('L', size, color = 'black')
141
+ # draw circle polygon
142
+ draw = ImageDraw.Draw(img)
143
+ for polygon in polygons:
144
+ draw.polygon(polygon, outline='white', width=1, fill='white')
145
+ # use skimage to find the mass center of the circle
146
+ circle_imga = (np.array(img)>0).astype(np.uint8)
147
+ property = regionprops(circle_imga)[0]
148
+ return property
149
+ def estimate_ellipse(coords, enlarge_factor=1.0):
150
+ em = EllipseModel()
151
+ em.estimate(coords[:, ::-1])
152
+ y, x, a, b, theta = em.params
153
+ a, b = a*enlarge_factor, b*enlarge_factor
154
+ em_params = np.round([y,x, a, b]).astype('int')
155
+ c, r = ellipse_perimeter(*em_params, orientation=-theta)
156
+ return em_params, theta, (c, r)
157
+ def estimate_line(coords):
158
+ lm = LineModelND()
159
+ lm.estimate(coords)
160
+ return lm.params
161
+ #%%
162
+ #%%
163
+ from matplotlib import pyplot as plt
164
+ import matplotlib
165
+ from skimage.measure import LineModelND, ransac
166
+ matplotlib.style.use('dark_background')
167
+ def rotate_theta(theta):
168
+ return ((theta + 3*np.pi/2)%(2*np.pi))/(2*np.pi)*360
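+ # Note on rotate_theta: it maps an arctan2 angle in image coordinates (y axis pointing
+ # down) to degrees in [0, 360) measured clockwise from the 6 o'clock direction:
+ # down -> 0, left -> 90, up -> 180, right -> 270. On a typical gauge the scale sweeps
+ # clockwise, so the printed values grow roughly linearly with this angle, which is what
+ # the psi-vs-angle line fit below relies on.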
169
+ kg_cm2_labels = list(map(str, [1,3,5,7,9,11]))
170
+ psi_labels = list(map(str, range(20, 180, 20)))
171
+
172
+ # lousy decoupling
173
+ MeterResult = namedtuple('MeterResult', [
174
+ 'result',
175
+ 'needle_psi',
176
+ 'needle_kg_cm2',
177
+ 'needle_theta',
178
+ 'orign',
179
+ 'direction',
180
+ 'center',
181
+
182
+ 'lm',
183
+ 'inliers',
184
+
185
+ 'kg_cm2_texts',
186
+ 'psi_texts',
187
+ 'kg_cm2_centers',
188
+ 'psi_centers',
189
+ 'kg_cm2_theta',
190
+ 'psi_theta',
191
+ 'kg_cm2_psi',
192
+ 'psi' ,
193
+ ])
194
+
195
+ def read_meter(img:Image.Image|str|Path, fl, fl_ft):
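+ """Estimate the reading of a circular pressure gauge from an image.
+ Pipeline sketch: locate and crop the meter, segment the needle, estimate the needle
+ direction and the meter centre, OCR the printed kg/cm2 and psi labels, convert the
+ label positions to angles around the centre, fit a robust psi-vs-angle line with
+ RANSAC, then read the needle angle off that line."""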
196
+ # ai model results
197
+ result = get_ai_model_result(img, fl, fl_ft)
198
+
199
+ # needle direction
200
+ coords = np.concatenate(result.needle_polygons).reshape(-1, 2)
201
+ orign, direction = estimate_line(coords)
202
+
203
+ # calculate the meter center
204
+ circle_props = get_regionprops(result.circle_polygons)
205
+ center = circle_props.centroid[::-1]
206
+
207
+ # XXX: orient the direction so it points from the center toward orign (the centroid of the needle mask)
208
+ if (orign - center) @ direction < 0:
209
+ direction = -direction
210
+
211
+ # calculate the needle theta
212
+ needle_theta = rotate_theta(np.arctan2(direction[1], direction[0]))
213
+
214
+ # parse the OCR results to find kg/cm2 and psi scale labels
215
+ ocr1, ocr2 = result.ocr1, result.ocr2
216
+ kg_cm2_texts = {}
217
+ psi_texts = {}
218
+ quad_boxes = ocr1['quad_boxes']+ocr2['quad_boxes']
219
+ labels = ocr1['labels']+ocr2['labels']
220
+ for qbox, label in zip(quad_boxes, labels):
221
+ if label in kg_cm2_labels:
222
+ kg_cm2_texts[int(label)]=qbox
223
+ if label in psi_labels:
224
+ psi_texts[int(label)]=qbox
225
+ # calculate the center of kg/cm2 and psi labels
226
+ kg_cm2_centers = np.array(list(kg_cm2_texts.values())).reshape(-1, 4, 2).mean(axis=1)
227
+ psi_centers = np.array(list(psi_texts.values())).reshape(-1, 4, 2).mean(axis=1)
228
+
229
+ # convert kg/cm2 and psi labels to polar coordinates, origin is the center of the meter
230
+ # the angle is in degree which is more intuitive
231
+ kg_cm2_coords = kg_cm2_centers - center
232
+ kg_cm2_theta = rotate_theta(np.arctan2(kg_cm2_coords[:, 1], kg_cm2_coords[:, 0]))
233
+ psi_coords = psi_centers - center
234
+ psi_theta = rotate_theta(np.arctan2(psi_coords[:, 1], psi_coords[:, 0]))
235
+
236
+ # convert kg/cm2 to psi (1 kgf/cm² ≈ 14.223 psi) so both label sets share one unit for the line fit
237
+ kg_cm2 = np.array(list(kg_cm2_texts.keys()))
238
+ kg_cm2_psi = kg_cm2 * 14.223
239
+ # combine kg/cm2 and psi labels to fit a line model
240
+ psi = np.array(list(psi_texts.keys()))
241
+ Y = np.concatenate([kg_cm2_psi, psi])
242
+ X = np.concatenate([kg_cm2_theta, psi_theta])
243
+ data = np.stack([X, Y], axis=1)
244
+ # run ransac to robustly fit a line model
245
+ lm, inliers = ransac(data, LineModelND, min_samples=2,
246
+ residual_threshold=15,
247
+ max_trials=2)
248
+
249
+ # use the fitted line model to calculate the needle psi and kg/cm2
250
+ needle_psi = lm.predict(needle_theta)[1]
251
+ needle_kg_cm2 = needle_psi / 14.223
252
+
253
+ return MeterResult(result=result,
254
+ needle_psi=needle_psi,
255
+ needle_kg_cm2=needle_kg_cm2,
256
+ needle_theta=needle_theta,
257
+ orign=orign,
258
+ direction=direction,
259
+ center=center,
260
+ lm=lm,
261
+ inliers=data[inliers].T,
262
+ kg_cm2_texts=kg_cm2_texts,
263
+ psi_texts=psi_texts,
264
+ kg_cm2_centers=kg_cm2_centers,
265
+ psi_centers=psi_centers,
266
+ kg_cm2_theta=kg_cm2_theta,
267
+ psi_theta=psi_theta,
268
+ kg_cm2_psi=kg_cm2_psi,
269
+ psi=psi,
270
+ )
271
+
272
+
273
+ def more_visualization_data(meter_result:MeterResult):
274
+ result = meter_result.result
275
+ center = meter_result.center
276
+ # following calculations are for visualization and debugging
277
+ # calculate the needle head (the farthest point from the center)
278
+ needle_coordinates = np.concatenate(result.needle_polygons).reshape(-1, 2)
279
+ needle_length = np.linalg.norm(needle_coordinates - center,axis=1)
280
+ farest_idx = np.argmax(needle_length)
281
+ needle_head = needle_coordinates[farest_idx]
282
+ needle_head_length = needle_length[farest_idx]
283
+ direction = meter_result.direction * needle_head_length
284
+
285
+ # inlier data
286
+ inlier_theta, inlier_psi = meter_result.inliers
287
+
288
+ # predict psi from 0 to 360
289
+ predict_theta = np.linspace(0, 360, 100)
290
+ predict_psi = meter_result.lm.predict(predict_theta)[:, 1]
291
+ return inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head, direction
292
+
293
+ def visualization(meter_result:MeterResult):
294
+ result = meter_result.result
295
+ center = meter_result.center
296
+ needle_psi, needle_kg_cm2 = meter_result.needle_psi, meter_result.needle_kg_cm2
297
+ inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head, direction = more_visualization_data(meter_result)
298
+ # drawing and visualization
299
+ draw = ImageDraw.Draw(result.img2.copy())
300
+ # draw needle polygons
301
+ for polygon in result.needle_polygons:
302
+ draw.polygon(polygon, outline='red', width=3)
303
+
304
+ # draw center circle
305
+ draw = ImageDraw.Draw(draw._image.convert('RGBA'))
306
+
307
+ draw2 = ImageDraw.Draw(Image.new('RGBA', draw._image.size, (0,0,0,0)))
308
+ for polygon in result.circle_polygons:
309
+ draw2.polygon(polygon, outline='purple', width=1, fill = (255,128,255,100))
310
+ img = Image.alpha_composite(draw._image, draw2._image)
311
+ draw = ImageDraw.Draw(img.convert('RGB'))
312
+
313
+ # draw needle direction
314
+ draw.line((center[0], center[1], center[0]+direction[0], center[1]+direction[1]), fill='yellow', width=3)
315
+ # draw a dot at center
316
+ draw.ellipse((center[0]-5, center[1]-5, center[0]+5, center[1]+5), outline='yellow', width=3)
317
+ # draw a dot at needle_head
318
+ draw.ellipse((needle_head[0]-5, needle_head[1]-5, needle_head[0]+5, needle_head[1]+5), outline='yellow', width=3)
319
+
320
+ for x,y in meter_result.kg_cm2_centers:
321
+ draw.ellipse((x-3, y-3, x+3, y+3), outline='blue', width=3)
322
+ for x,y in meter_result.psi_centers:
323
+ draw.ellipse((x-3, y-3, x+3, y+3), outline='green', width=3)
324
+ for label,quad_box in meter_result.kg_cm2_texts.items():
325
+ draw.polygon(quad_box, outline='blue', width=1)
326
+ draw.text((quad_box[0], quad_box[1]-10), str(label), fill='blue', anchor='ls')
327
+ for label,quad_box in meter_result.psi_texts.items():
328
+ draw.polygon(quad_box, outline='green', width=1)
329
+ draw.text((quad_box[0], quad_box[1]-10), str(label), fill='green', anchor='ls')
330
+
331
+ if len(meter_result.kg_cm2_centers) >4:
332
+ # the ellipse of kg/cm2 labels, currently only for visualization
333
+ em_params, theta, (c, r) = estimate_ellipse(meter_result.kg_cm2_centers)
334
+ y, x = em_params[:2]
335
+ draw.ellipse((x-5, y-5, x+5, y+5), outline='blue', width=1)
336
+ imga = np.array(draw._image)
337
+ imga[c,r] = (0, 0, 255)
338
+ draw = ImageDraw.Draw(Image.fromarray(imga))
339
+
340
+ if len(meter_result.psi_centers) >4:
341
+ # the ellipse of psi labels, currently only for visualization
342
+ em_params, theta, (c, r) = estimate_ellipse(meter_result.psi_centers)
343
+ y, x = em_params[:2]
+ draw.ellipse((x-5, y-5, x+5, y+5), outline='green', width=1)
344
+ imga = np.array(draw._image)
345
+ imga[c,r] = (0, 255, 0)
346
+ y, x = em_params[:2]
347
+ draw = ImageDraw.Draw(Image.fromarray(imga))
348
+ draw.text((needle_head[0]-10, needle_head[1]-10),
349
+ f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}',anchor='ls',
350
+ fill='yellow')
351
+ plt.plot(predict_theta, predict_psi, color='red', alpha=0.5)
352
+ plt.plot(meter_result.kg_cm2_theta, meter_result.kg_cm2_psi, 'o', color='#77F')
353
+ plt.plot(meter_result.psi_theta, meter_result.psi, 'o', color='#7F7')
354
+ plt.plot(inlier_theta, inlier_psi, 'x', color='red', alpha=0.5)
355
+ plt.vlines(meter_result.needle_theta, 0, 160, colors='yellow', alpha=0.5)
356
+ plt.hlines(meter_result.needle_psi, 0, 360, colors='yellow', alpha=0.5)
357
+
358
+ plt.text(meter_result.needle_theta-20, meter_result.needle_psi-20,
359
+ f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}', color='yellow')
360
+ plt.xlim(0, 360)
361
+ plt.ylim(0, 160)
362
+ return draw._image, plt.gcf()
363
+
364
+ def clear_cache():
365
+ cached_results.clear()
366
+ def save_cache():
367
+ pickle.dump(cached_results, open('cached_results.pkl', 'wb'))
368
+ def load_cache():
369
+ global cached_results
370
+ cached_results = pickle.load(open('cached_results.pkl', 'rb'))
371
+ #%%
372
+ if __name__ == '__main__':
373
+ from io import BytesIO
374
+ fl, fl_ft = model_init(hack=False)
375
+ #load_cache()
376
+ clear_cache()
377
+ imgs = list(Path('images/good').glob('*.jpg'))#[-1:]
378
+ W, H = 640, 480
379
+ for img_fn in imgs:
380
+ print(img_fn)
381
+ meter_result = read_meter(img_fn, fl, fl_ft)
382
+ img, fig = visualization(meter_result)
383
+ # resize the annotated image to fit WxH while keeping the aspect ratio
384
+ w, h = meter_result.result.img2.size
385
+ if w/W > h/H:
386
+ w, h = W, int(h*W/w)
387
+ else:
388
+ w, h = int(w*H/h), H
389
+ display(img.resize((w, h)))
390
+ # convert figure to PIL image using io.BytesIO
391
+ buf = BytesIO()
392
+ fig.savefig(buf, format='png')
393
+ buf.seek(0)
394
+ fig_img = Image.open(buf)
395
+ display(fig_img)
396
+ # clear plot
397
+ plt.clf()
398
+
399
+
400
+
401
+
402
+
403
+
404
+ # %%
environment.yml ADDED
@@ -0,0 +1,435 @@
1
+ name: guage
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=conda_forge
9
+ - _openmp_mutex=4.5=2_gnu
10
+ - aom=3.6.1=h59595ed_0
11
+ - binutils_impl_linux-64=2.43=h4bf12b8_2
12
+ - binutils_linux-64=2.43=h4852527_2
13
+ - blas=1.0=mkl
14
+ - blosc=1.21.6=hef167b5_0
15
+ - brotli=1.1.0=hb9d3cd8_2
16
+ - brotli-bin=1.1.0=hb9d3cd8_2
17
+ - brotli-python=1.1.0=py311hfdbb021_2
18
+ - brunsli=0.1=h9c3ff4c_0
19
+ - bzip2=1.0.8=h4bc722e_7
20
+ - c-blosc2=2.14.4=hb4ffafa_1
21
+ - ca-certificates=2024.8.30=hbcca054_0
22
+ - certifi=2024.8.30=pyhd8ed1ab_0
23
+ - cffi=1.17.1=py311hf29c0ef_0
24
+ - charls=2.4.2=h59595ed_0
25
+ - charset-normalizer=3.4.0=pyhd8ed1ab_0
26
+ - cpython=3.11.10=py311hd8ed1ab_3
27
+ - cuda-cccl=12.6.77=0
28
+ - cuda-cccl_linux-64=12.6.77=0
29
+ - cuda-command-line-tools=12.1.1=0
30
+ - cuda-compiler=12.6.2=0
31
+ - cuda-crt-dev_linux-64=12.6.20=0
32
+ - cuda-crt-tools=12.6.20=0
33
+ - cuda-cudart=12.1.105=0
34
+ - cuda-cudart-dev=12.1.105=0
35
+ - cuda-cudart-dev_linux-64=12.6.77=0
36
+ - cuda-cudart-static=12.6.77=0
37
+ - cuda-cudart-static_linux-64=12.6.77=0
38
+ - cuda-cudart_linux-64=12.6.77=0
39
+ - cuda-cuobjdump=12.6.77=0
40
+ - cuda-cupti=12.1.105=0
41
+ - cuda-cuxxfilt=12.6.77=0
42
+ - cuda-documentation=12.4.127=0
43
+ - cuda-driver-dev=12.6.77=0
44
+ - cuda-driver-dev_linux-64=12.6.77=0
45
+ - cuda-gdb=12.6.77=0
46
+ - cuda-libraries=12.1.0=0
47
+ - cuda-libraries-dev=12.6.2=0
48
+ - cuda-libraries-static=12.6.2=0
49
+ - cuda-nsight=12.6.77=0
50
+ - cuda-nvcc=12.6.20=0
51
+ - cuda-nvcc-dev_linux-64=12.6.20=0
52
+ - cuda-nvcc-impl=12.6.20=0
53
+ - cuda-nvcc-tools=12.6.20=0
54
+ - cuda-nvcc_linux-64=12.6.20=0
55
+ - cuda-nvdisasm=12.6.77=0
56
+ - cuda-nvml-dev=12.6.77=2
57
+ - cuda-nvprof=12.6.80=0
58
+ - cuda-nvprune=12.6.77=0
59
+ - cuda-nvrtc=12.1.105=0
60
+ - cuda-nvrtc-dev=12.1.105=0
61
+ - cuda-nvrtc-static=12.6.77=0
62
+ - cuda-nvtx=12.1.105=0
63
+ - cuda-nvvm-dev_linux-64=12.6.20=0
64
+ - cuda-nvvm-impl=12.6.20=0
65
+ - cuda-nvvm-tools=12.6.20=0
66
+ - cuda-nvvp=12.6.80=0
67
+ - cuda-opencl=12.6.77=0
68
+ - cuda-opencl-dev=12.6.77=0
69
+ - cuda-profiler-api=12.6.77=0
70
+ - cuda-runtime=12.1.0=0
71
+ - cuda-sanitizer-api=12.6.77=0
72
+ - cuda-toolkit=12.1.0=0
73
+ - cuda-tools=12.1.1=0
74
+ - cuda-version=12.6=3
75
+ - cuda-visual-tools=12.6.2=0
76
+ - dav1d=1.2.1=hd590300_0
77
+ - dbus=1.13.18=hb2f20db_0
78
+ - expat=2.6.4=h5888daf_0
79
+ - ffmpeg=4.4.0=h6987444_4
80
+ - filelock=3.16.1=pyhd8ed1ab_0
81
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
82
+ - font-ttf-inconsolata=3.000=h77eed37_0
83
+ - font-ttf-source-code-pro=2.038=h77eed37_0
84
+ - font-ttf-ubuntu=0.83=h77eed37_3
85
+ - fontconfig=2.15.0=h7e30c49_1
86
+ - fonts-conda-ecosystem=1=0
87
+ - fonts-conda-forge=1=0
88
+ - freetype=2.12.1=h267a509_2
89
+ - gcc_impl_linux-64=12.4.0=hb2e57f8_1
90
+ - gcc_linux-64=12.4.0=h6b7512a_5
91
+ - gds-tools=1.11.1.6=0
92
+ - gettext=0.22.5=he02047a_3
93
+ - gettext-tools=0.22.5=he02047a_3
94
+ - giflib=5.2.2=hd590300_0
95
+ - glib=2.82.2=h44428e9_0
96
+ - glib-tools=2.82.2=h4833e2c_0
97
+ - gmp=6.3.0=hac33072_2
98
+ - gmpy2=2.1.5=py311h0f6cedb_2
99
+ - gnutls=3.6.13=h85f3911_1
100
+ - gxx_impl_linux-64=12.4.0=h613a52c_1
101
+ - gxx_linux-64=12.4.0=h8489865_5
102
+ - h2=4.1.0=pyhd8ed1ab_0
103
+ - hpack=4.0.0=pyh9f0ad1d_0
104
+ - hyperframe=6.0.1=pyhd8ed1ab_0
105
+ - idna=3.10=pyhd8ed1ab_0
106
+ - imagecodecs=2024.1.1=py311hbe88301_6
107
+ - imageio=2.36.0=pyh12aca89_1
108
+ - importlib-metadata=8.5.0=pyha770c72_0
109
+ - intel-openmp=2022.0.1=h06a4308_3633
110
+ - jinja2=3.1.4=pyhd8ed1ab_0
111
+ - jxrlib=1.1=hd590300_3
112
+ - kernel-headers_linux-64=3.10.0=he073ed8_18
113
+ - lame=3.100=h166bdaf_1003
114
+ - lazy-loader=0.4=pyhd8ed1ab_1
115
+ - lazy_loader=0.4=pyhd8ed1ab_1
116
+ - lcms2=2.16=hb7c19ff_0
117
+ - ld_impl_linux-64=2.43=h712a8e2_2
118
+ - lerc=4.0.0=h27087fc_0
119
+ - libaec=1.1.3=h59595ed_0
120
+ - libasprintf=0.22.5=he8f35ee_3
121
+ - libasprintf-devel=0.22.5=he8f35ee_3
122
+ - libavif16=1.0.1=h87da1f6_2
123
+ - libblas=3.9.0=16_linux64_mkl
124
+ - libbrotlicommon=1.1.0=hb9d3cd8_2
125
+ - libbrotlidec=1.1.0=hb9d3cd8_2
126
+ - libbrotlienc=1.1.0=hb9d3cd8_2
127
+ - libcblas=3.9.0=16_linux64_mkl
128
+ - libcublas=12.1.0.26=0
129
+ - libcublas-dev=12.1.0.26=0
130
+ - libcublas-static=12.6.3.3=0
131
+ - libcufft=11.0.2.4=0
132
+ - libcufft-dev=11.0.2.4=0
133
+ - libcufft-static=11.3.0.4=0
134
+ - libcufile=1.11.1.6=0
135
+ - libcufile-dev=1.11.1.6=0
136
+ - libcufile-static=1.11.1.6=0
137
+ - libcurand=10.3.7.77=0
138
+ - libcurand-dev=10.3.7.77=0
139
+ - libcurand-static=10.3.7.77=0
140
+ - libcusolver=11.4.4.55=0
141
+ - libcusolver-dev=11.4.4.55=0
142
+ - libcusolver-static=11.7.1.2=0
143
+ - libcusparse=12.0.2.55=0
144
+ - libcusparse-dev=12.0.2.55=0
145
+ - libcusparse-static=12.5.4.2=0
146
+ - libdeflate=1.20=hd590300_0
147
+ - libdrm=2.4.123=hb9d3cd8_0
148
+ - libegl=1.7.0=ha4b6fd6_1
149
+ - libexpat=2.6.4=h5888daf_0
150
+ - libffi=3.4.2=h7f98852_5
151
+ - libgcc=14.2.0=h77fa898_1
152
+ - libgcc-devel_linux-64=12.4.0=ha4f9413_101
153
+ - libgcc-ng=14.2.0=h69a702a_1
154
+ - libgettextpo=0.22.5=he02047a_3
155
+ - libgettextpo-devel=0.22.5=he02047a_3
156
+ - libgfortran=14.2.0=h69a702a_1
157
+ - libgfortran5=14.2.0=hd5240d6_1
158
+ - libgl=1.7.0=ha4b6fd6_1
159
+ - libglib=2.82.2=h2ff4ddf_0
160
+ - libglvnd=1.7.0=ha4b6fd6_1
161
+ - libglx=1.7.0=ha4b6fd6_1
162
+ - libgomp=14.2.0=h77fa898_1
163
+ - libhwy=1.1.0=h00ab1b0_0
164
+ - libiconv=1.17=hd590300_2
165
+ - libidn2=2.3.7=hd590300_0
166
+ - libjpeg-turbo=3.0.0=hd590300_1
167
+ - libjxl=0.10.3=h66b40c8_0
168
+ - liblapack=3.9.0=16_linux64_mkl
169
+ - libnpp=12.0.2.50=0
170
+ - libnpp-dev=12.0.2.50=0
171
+ - libnpp-static=12.3.1.54=0
172
+ - libnsl=2.0.1=hd590300_0
173
+ - libnvfatbin=12.6.77=0
174
+ - libnvfatbin-dev=12.6.77=0
175
+ - libnvfatbin-static=12.6.77=0
176
+ - libnvjitlink=12.1.105=0
177
+ - libnvjitlink-dev=12.1.105=0
178
+ - libnvjitlink-static=12.6.77=0
179
+ - libnvjpeg=12.1.1.14=0
180
+ - libnvjpeg-dev=12.1.1.14=0
181
+ - libnvjpeg-static=12.3.3.54=0
182
+ - libnvvm-samples=12.1.105=0
183
+ - libpciaccess=0.18=hd590300_0
184
+ - libpng=1.6.44=hadc24fc_0
185
+ - libsanitizer=12.4.0=h46f95d5_1
186
+ - libsqlite=3.47.0=hadc24fc_1
187
+ - libstdcxx=14.2.0=hc0a3c3a_1
188
+ - libstdcxx-devel_linux-64=12.4.0=ha4f9413_101
189
+ - libstdcxx-ng=14.2.0=h4852527_1
190
+ - libtasn1=4.19.0=h166bdaf_0
191
+ - libtiff=4.6.0=h1dd3fc0_3
192
+ - libunistring=0.9.10=h7f98852_0
193
+ - libuuid=2.38.1=h0b41bf4_0
194
+ - libva=2.22.0=h8a09558_1
195
+ - libvpx=1.11.0=h9c3ff4c_3
196
+ - libwebp=1.4.0=h2c329e2_0
197
+ - libwebp-base=1.4.0=hd590300_0
198
+ - libxcb=1.17.0=h8a09558_0
199
+ - libxcrypt=4.4.36=hd590300_1
200
+ - libxkbcommon=1.7.0=h2c5496b_1
201
+ - libxml2=2.13.4=h064dc61_2
202
+ - libzlib=1.3.1=hb9d3cd8_2
203
+ - libzopfli=1.0.3=h9c3ff4c_0
204
+ - llvm-openmp=15.0.7=h0cdce71_0
205
+ - lz4-c=1.9.4=hcb278e6_0
206
+ - mkl=2022.1.0=hc2b9512_224
207
+ - mpc=1.3.1=h24ddda3_1
208
+ - mpfr=4.2.1=h90cbb55_3
209
+ - mpmath=1.3.0=pyhd8ed1ab_0
210
+ - ncurses=6.5=he02047a_1
211
+ - nettle=3.6=he412f7d_0
212
+ - networkx=3.4.2=pyhd8ed1ab_1
213
+ - nsight-compute=2024.3.2.3=0
214
+ - nspr=4.36=h5888daf_0
215
+ - nss=3.106=hdf54f9c_0
216
+ - numpy=1.26.4=py311h64a7726_0
217
+ - openh264=2.1.1=h4ff587b_0
218
+ - openjpeg=2.5.2=h488ebb8_0
219
+ - openssl=3.3.2=hb9d3cd8_0
220
+ - p11-kit=0.24.1=hc5aa10d_0
221
+ - packaging=24.1=pyhd8ed1ab_0
222
+ - pcre2=10.44=hba22ea6_2
223
+ - pillow=10.4.0=py311h4aec55e_1
224
+ - pip=24.3.1=pyh8b19718_0
225
+ - pthread-stubs=0.4=hb9d3cd8_1002
226
+ - pycparser=2.22=pyhd8ed1ab_0
227
+ - pysocks=1.7.1=pyha2e5f31_6
228
+ - python=3.11.10=hc5c86c4_3_cpython
229
+ - python_abi=3.11=5_cp311
230
+ - pytorch=2.5.1=py3.11_cuda12.1_cudnn9.1.0_0
231
+ - pytorch-cuda=12.1=ha16c6d3_6
232
+ - pytorch-mutex=1.0=cuda
233
+ - pywavelets=1.7.0=py311h9f3472d_2
234
+ - pyyaml=6.0.2=py311h9ecbd09_1
235
+ - rav1e=0.6.6=he8a937b_2
236
+ - readline=8.2=h8228510_1
237
+ - requests=2.32.3=pyhd8ed1ab_0
238
+ - scikit-image=0.24.0=py311h7db5c69_3
239
+ - scipy=1.14.1=py311he9a78e4_1
240
+ - setuptools=75.3.0=pyhd8ed1ab_0
241
+ - snappy=1.2.1=ha2e4443_0
242
+ - svt-av1=1.7.0=h59595ed_0
243
+ - sysroot_linux-64=2.17=h4a8ded7_18
244
+ - tifffile=2024.9.20=pyhd8ed1ab_0
245
+ - tk=8.6.13=noxft_h4845f30_101
246
+ - torchaudio=2.5.1=py311_cu121
247
+ - torchtriton=3.1.0=py311
248
+ - torchvision=0.20.1=py311_cu121
249
+ - typing_extensions=4.12.2=pyha770c72_0
250
+ - urllib3=2.2.3=pyhd8ed1ab_0
251
+ - wayland=1.23.1=h3e06ad9_0
252
+ - wayland-protocols=1.37=hd8ed1ab_0
253
+ - wheel=0.45.0=pyhd8ed1ab_0
254
+ - x264=1!161.3030=h7f98852_1
255
+ - x265=3.5=h924138e_3
256
+ - xkeyboard-config=2.43=hb9d3cd8_0
257
+ - xorg-libx11=1.8.10=h4f16b4b_0
258
+ - xorg-libxau=1.0.11=hb9d3cd8_1
259
+ - xorg-libxdmcp=1.1.5=hb9d3cd8_0
260
+ - xorg-libxext=1.3.6=hb9d3cd8_0
261
+ - xorg-libxfixes=6.0.1=hb9d3cd8_0
262
+ - xorg-xorgproto=2024.1=hb9d3cd8_1
263
+ - xz=5.2.6=h166bdaf_0
264
+ - yaml=0.2.5=h7f98852_2
265
+ - zfp=1.0.1=h5888daf_2
266
+ - zipp=3.20.2=pyhd8ed1ab_0
267
+ - zlib-ng=2.0.7=h0b41bf4_0
268
+ - zstandard=0.23.0=py311hbc35293_1
269
+ - zstd=1.5.6=ha6fb4c9_0
270
+ - pip:
271
+ - accelerate==1.1.1
272
+ - aiofiles==23.2.1
273
+ - albucore==0.0.13
274
+ - albumentations==1.4.10
275
+ - annotated-types==0.7.0
276
+ - anyio==3.7.1
277
+ - argon2-cffi==23.1.0
278
+ - argon2-cffi-bindings==21.2.0
279
+ - arrow==1.3.0
280
+ - astor==0.8.1
281
+ - asttokens==2.4.1
282
+ - async-lru==2.0.4
283
+ - attrs==24.2.0
284
+ - azure-core==1.32.0
285
+ - azure-identity==1.19.0
286
+ - babel==2.16.0
287
+ - beautifulsoup4==4.12.3
288
+ - bleach==6.2.0
289
+ - click==8.1.7
290
+ - comm==0.2.2
291
+ - contourpy==1.3.0
292
+ - cryptography==43.0.3
293
+ - cycler==0.12.1
294
+ - cython==3.0.11
295
+ - debugpy==1.8.8
296
+ - decorator==5.1.1
297
+ - defusedxml==0.7.1
298
+ - dill==0.3.9
299
+ - distro==1.9.0
300
+ - easyocr==1.7.2
301
+ - einops==0.8.0
302
+ - executing==2.1.0
303
+ - fastapi==0.115.4
304
+ - fastjsonschema==2.20.0
305
+ - ffmpy==0.4.0
306
+ - fire==0.7.0
307
+ - flash-attn==2.6.3
308
+ - fonttools==4.54.1
309
+ - fqdn==1.5.1
310
+ - fsspec==2024.10.0
311
+ - gradio==5.5.0
312
+ - gradio-client==1.4.2
313
+ - h11==0.14.0
314
+ - httpcore==1.0.6
315
+ - httpx==0.27.2
316
+ - huggingface-hub==0.26.2
317
+ - imgaug==0.4.0
318
+ - ipykernel==6.29.5
319
+ - ipython==8.29.0
320
+ - isoduration==20.11.0
321
+ - jedi==0.19.1
322
+ - joblib==1.4.2
323
+ - json5==0.9.25
324
+ - jsonpointer==3.0.0
325
+ - jsonschema==4.23.0
326
+ - jsonschema-specifications==2024.10.1
327
+ - jupyter-client==8.6.3
328
+ - jupyter-core==5.7.2
329
+ - jupyter-events==0.10.0
330
+ - jupyter-lsp==2.2.5
331
+ - jupyter-server==2.14.2
332
+ - jupyter-server-terminals==0.5.3
333
+ - jupyterlab==4.3.0
334
+ - jupyterlab-pygments==0.3.0
335
+ - jupyterlab-server==2.27.3
336
+ - kiwisolver==1.4.7
337
+ - lmdb==1.5.1
338
+ - lxml==5.3.0
339
+ - markdown-it-py==3.0.0
340
+ - markupsafe==2.1.5
341
+ - matplotlib==3.9.2
342
+ - matplotlib-inline==0.1.7
343
+ - mdurl==0.1.2
344
+ - mistune==3.0.2
345
+ - msal==1.31.0
346
+ - msal-extensions==1.2.0
347
+ - nbclient==0.10.0
348
+ - nbconvert==7.16.4
349
+ - nbformat==5.10.4
350
+ - nest-asyncio==1.6.0
351
+ - ninja==1.11.1.1
352
+ - notebook-shim==0.2.4
353
+ - openai==1.3.5
354
+ - opencv-contrib-python==4.10.0.84
355
+ - opencv-python==4.10.0.84
356
+ - opencv-python-headless==4.10.0.84
357
+ - opt-einsum==3.3.0
358
+ - orjson==3.10.11
359
+ - overrides==7.7.0
360
+ - paddleocr==2.9.1
361
+ - paddlepaddle==2.6.2
362
+ - pandas==2.2.3
363
+ - pandocfilters==1.5.1
364
+ - parso==0.8.4
365
+ - pexpect==4.9.0
366
+ - platformdirs==4.3.6
367
+ - portalocker==2.10.1
368
+ - prometheus-client==0.21.0
369
+ - prompt-toolkit==3.0.48
370
+ - protobuf==5.28.3
371
+ - psutil==6.1.0
372
+ - ptyprocess==0.7.0
373
+ - pure-eval==0.2.3
374
+ - py-cpuinfo==9.0.0
375
+ - pyclipper==1.3.0.post6
376
+ - pydantic==2.9.2
377
+ - pydantic-core==2.23.4
378
+ - pydub==0.25.1
379
+ - pygments==2.18.0
380
+ - pyjwt==2.9.0
381
+ - pyparsing==3.2.0
382
+ - python-bidi==0.6.3
383
+ - python-dateutil==2.9.0.post0
384
+ - python-docx==1.1.2
385
+ - python-json-logger==2.0.7
386
+ - python-multipart==0.0.12
387
+ - pytz==2024.2
388
+ - pyzmq==26.2.0
389
+ - rapidfuzz==3.10.1
390
+ - referencing==0.35.1
391
+ - regex==2024.11.6
392
+ - rfc3339-validator==0.1.4
393
+ - rfc3986-validator==0.1.1
394
+ - rich==13.9.4
395
+ - rpds-py==0.21.0
396
+ - ruff==0.7.3
397
+ - safehttpx==0.1.1
398
+ - safetensors==0.4.5
399
+ - scikit-learn==1.5.2
400
+ - seaborn==0.13.2
401
+ - semantic-version==2.10.0
402
+ - send2trash==1.8.3
403
+ - shapely==2.0.6
404
+ - shellingham==1.5.4
405
+ - six==1.16.0
406
+ - sniffio==1.3.1
407
+ - soupsieve==2.6
408
+ - stack-data==0.6.3
409
+ - starlette==0.41.2
410
+ - supervision==0.18.0
411
+ - sympy==1.13.1
412
+ - termcolor==2.5.0
413
+ - terminado==0.18.1
414
+ - thop==0.1.1-2209072238
415
+ - threadpoolctl==3.5.0
416
+ - timm==1.0.11
417
+ - tinycss2==1.4.0
418
+ - tokenizers==0.20.3
419
+ - tomli==2.0.2
420
+ - tomlkit==0.12.0
421
+ - tornado==6.4.1
422
+ - tqdm==4.67.0
423
+ - traitlets==5.14.3
424
+ - transformers==4.46.2
425
+ - typer==0.13.0
426
+ - types-python-dateutil==2.9.0.20241003
427
+ - tzdata==2024.2
428
+ - ultralytics==8.1.24
429
+ - uri-template==1.3.0
430
+ - uvicorn==0.32.0
431
+ - wcwidth==0.2.13
432
+ - webcolors==24.8.0
433
+ - webencodings==0.5.1
434
+ - websocket-client==1.8.0
435
+ - websockets==12.0
main.py ADDED
@@ -0,0 +1,57 @@
1
+ import gradio as gr
2
+ #from florence import model_init, draw_image
3
+ #from wikai import analyze_dial, ocr_and_od
4
+ import matplotlib.pyplot as plt
5
+ from aimodel import model_init, read_meter, visualization
6
+ from test_rect import read_meter as read_meter_rect
7
+ from PIL import Image
8
+ import logging
9
+ #logging.basicConfig(level=logging.DEBUG)
10
+ print("Loading model...")
11
+ fl, fl_ft = model_init(hack=False)
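+ # The UI strings below are Traditional Chinese; rough translations:
+ # 指針辨識系統 = "needle/dial reading system"; 請選擇儀表類型,上傳圖片,或點擊Submit = "choose the
+ # meter type, upload an image, or click Submit"; 圓形儀表 = "circular meter"; 方形儀表 =
+ # "rectangular meter"; 辨識結果 = "recognition result"; 選擇選項 = "choose an option";
+ # 上傳圖片 = "upload an image"; 輸出圖片 = "output image"; 模型結果 = "model output".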
12
+ def process_image(input_image:Image, meter_type:str):
13
+ if meter_type == "方形儀表":
14
+ value, img = read_meter_rect(input_image, fl, fl_ft)
15
+ return img, f"辨識結果: PA={value}", None
16
+ assert meter_type == "圓形儀表"
17
+ plt.clf()
18
+ print("process_image")
19
+ W, H = 640, 480
20
+ if input_image is None:
21
+ return None, None, None
22
+ meter_result = read_meter(input_image, fl, fl_ft)
23
+ img, fig = visualization(meter_result)
24
+ return img, f"辨識結果: PSI={meter_result.needle_psi:.1f} kg/cm²={meter_result.needle_kg_cm2:.2f} ", plt
25
+
26
+ with gr.Blocks() as demo:
27
+
28
+ gr.Markdown("## 指針辨識系統\n請選擇儀表類型,上傳圖片,或點擊Submit")
29
+
30
+ with gr.Row():
31
+ with gr.Column():
32
+ with gr.Row():
33
+ clear_button = gr.ClearButton()
34
+ submit_button = gr.Button("Submit", variant="primary")
35
+ meter_type_dropdown = gr.Dropdown(choices=["圓形儀表", "方形儀表"], label="選擇選項")
36
+ image_input = gr.Image(type="pil", label="上傳圖片")
37
+ with gr.Column():
38
+ number_output = gr.Textbox(label="辨識結果", placeholder="辨識結果")
39
+ image_output = gr.Image(label="輸出圖片")
40
+ plot_output = gr.Plot(label="模型結果")
41
+
42
+ clear_button.add([image_input, image_output, number_output])
43
+
44
+ image_input.upload(
45
+ fn=process_image,
46
+ inputs=[image_input, meter_type_dropdown],
47
+ outputs=[image_output, number_output, plot_output],
48
+ queue=False
49
+ )
50
+
51
+ submit_button.click(
52
+ fn=process_image,
53
+ inputs=[image_input, meter_type_dropdown],
54
+ outputs=[image_output, number_output, plot_output],
55
+ )
56
+
57
+ demo.launch(debug=True)
readme.txt ADDED
@@ -0,0 +1,5 @@
1
+ Use `environment.yml` to recreate the conda environment.
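+ For example: `conda env create -f environment.yml`, then `conda activate guage`.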
2
+
3
+ Use `python main.py` or `gradio main.py` to run the Gradio demo.
4
+
5
+ `images/good` collects sample images that the circular-meter pipeline handles well.
test_rect.py ADDED
@@ -0,0 +1,141 @@
1
+ # %%
2
+ import spaces
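+ # `spaces` is the Hugging Face Spaces helper package; the @spaces.GPU decorator near the
+ # bottom of this file requests a GPU when the app runs on ZeroGPU hardware (this assumes
+ # the script is deployed as a Hugging Face Space).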
3
+ import matplotlib.style
4
+ from transformers import AutoProcessor, AutoModelForCausalLM
5
+ from PIL import Image
6
+ import torch
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ from PIL import ImageDraw
10
+ from IPython.display import display
11
+ import numpy as np
12
+ from collections import namedtuple
13
+ import sys
14
+ print(sys.version_info)
15
+ #%%
16
+ class Florence:
17
+ def __init__(self, model_id:str, hack=False):
18
+ if hack:
19
+ return
20
+ self.model = (
21
+ AutoModelForCausalLM.from_pretrained(
22
+ model_id, trust_remote_code=True, torch_dtype="auto"
23
+ )
24
+ .eval()
25
+ .cuda()
26
+ )
27
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
28
+ self.model_id = model_id
29
+ def run(self, img:Image, task_prompt:str, extra_text:str|None=None):
30
+ model, processor = self.model, self.processor
31
+ prompt = task_prompt + (extra_text if extra_text else "")
32
+ inputs = processor(text=prompt, images=img, return_tensors="pt").to(
33
+ "cuda", torch.float16
34
+ )
35
+ generated_ids = model.generate(
36
+ input_ids=inputs["input_ids"],
37
+ pixel_values=inputs["pixel_values"],
38
+ max_new_tokens=1024,
39
+ early_stopping=False,
40
+ do_sample=False,
41
+ num_beams=3,
42
+ )
43
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
44
+ parsed_answer = processor.post_process_generation(
45
+ generated_text,
46
+ task=task_prompt,
47
+ image_size=(img.width, img.height),
48
+ )
49
+ return parsed_answer
50
+ def model_init():
51
+ fl = Florence("microsoft/Florence-2-large", hack=False)
52
+ fl_ft = Florence("microsoft/Florence-2-large-ft", hack=False)
53
+ return fl, fl_ft
54
+ # florence-2 tasks
55
+ TASK_OD = "<OD>"
56
+ TASK_SEGMENTATION = '<REFERRING_EXPRESSION_SEGMENTATION>'
57
+ TASK_CAPTION = "<CAPTION_TO_PHRASE_GROUNDING>"
58
+ TASK_OCR = "<OCR_WITH_REGION>"
59
+ TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
60
+
61
+ #%%
62
+ from skimage.measure import LineModelND, ransac
63
+ def get_polygons(fl:Florence, img2:Image, prompt):
64
+ parsed_answer = fl.run(img2, TASK_SEGMENTATION, prompt)
65
+ assert len(parsed_answer) == 1
66
+ k,v = parsed_answer.popitem()
67
+ assert 'polygons' in v
68
+ assert len(v['polygons']) == 1
69
+ polygons = v['polygons'][0]
70
+ return polygons
71
+
72
+ def get_ocr(fl:Florence, img2:Image):
73
+ parsed_answer = fl.run(img2, TASK_OCR)
74
+ assert len(parsed_answer)==1
75
+ k,v = parsed_answer.popitem()
76
+ return v
77
+ imgs = list(Path('images/other').glob('*.jpg'))
78
+ meter_labels = list(map(str, range(0, 600, 100)))
79
+
80
+ def read_meter(img, fl:Florence, fl_ft:Florence):
81
+ if isinstance(img, str) or isinstance(img, Path):
82
+ print(img)
83
+ img = Image.open(img)
84
+ red_polygons = get_polygons(fl, img, 'red triangle pointer')
85
+ # prepare to draw annotations on the image
86
+ draw = ImageDraw.Draw(img)
87
+ ocr_text = {}
88
+ ocr1 = get_ocr(fl, img)
89
+ ocr2 = get_ocr(fl_ft, img)
90
+ quad_boxes = ocr1['quad_boxes']+ocr2['quad_boxes']
91
+ labels = ocr1['labels']+ocr2['labels']
92
+ for quad_box, label in zip(quad_boxes, labels):
93
+ if label in meter_labels:
94
+ ocr_text[int(label)] = quad_box
95
+ for label, quad_box in ocr_text.items():
96
+ draw.polygon(quad_box, outline='green', width=3)
97
+ draw.text((quad_box[0], quad_box[1]-10), str(label), fill='green', anchor='ls')
98
+ text_centers = np.array(list(ocr_text.values())).reshape(-1, 4, 2).mean(axis=1)
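+ # The rectangular meter has a straight scale: fit a line (lm) through the OCR'd label
+ # centers, project every center onto that line to get a 1-D position, then fit a second
+ # model (lm2) that maps label value <-> projected position. The red pointer is projected
+ # onto the same line and converted back to a value through lm2.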
99
+ lm = LineModelND()
100
+ lm.estimate(text_centers)
101
+ orign, direction = lm.params
102
+ # project text centers to the line
103
+ text_centers_shifted = text_centers - orign
104
+ text_centers_norm = text_centers_shifted @ direction
105
+ lm2 = LineModelND()
106
+ I = np.array(list(ocr_text.keys()))
107
+ L = text_centers_norm
108
+ data = np.stack([I, L], axis=1)
109
+ lm2.estimate(data)
110
+ ls = lm2.predict(list(range(0, 600, 100)))[:, 1]
111
+ x0, y0 = ls[0] * direction + orign
112
+ x1, y1 = ls[-1] * direction + orign
113
+ draw.line((x0, y0, x1, y1), fill='yellow', width=3)
114
+ for l in ls:
115
+ x, y = l * direction + orign
116
+ draw.ellipse((x-5, y-5, x+5, y+5), outline='yellow', width=3)
117
+ red_coords = np.concatenate(red_polygons).reshape(-1, 2)
118
+ red_shifted = red_coords - orign
119
+ red_norm = red_shifted @ direction
120
+ red_l = red_norm.mean()
121
+ red_i = np.clip(lm2.predict_x([red_l]), 0, 500)
122
+ red_l = lm2.predict_y(red_i)[0]
123
+ red_center = red_l * direction + orign
124
+ draw.ellipse((red_center[0]-5, red_center[1]-5, red_center[0]+5, red_center[1]+5), outline='red', width=3)
125
+ return red_i[0], img
126
+
127
+
128
+
129
+ @spaces.GPU
130
+ def main():
131
+ fl, fl_ft = model_init()
132
+ for img_fn in imgs:
133
+ print(img_fn)
134
+ img = Image.open(img_fn)
135
+ red_i, img2 = read_meter(img, fl, fl_ft)
136
+ print(red_i)
137
+ display(img2)
138
+ if __name__ == '__main__':
139
+ main()
140
+
141
+ #%%