llzzyy233 commited on
Commit
a80c25e
1 Parent(s): 85fd8d2

去掉柱状图

Browse files
Files changed (1) hide show
  1. app.py +70 -71
app.py CHANGED
@@ -1,71 +1,70 @@
1
- import gradio as gr
2
- import torch
3
- from PIL import Image
4
- from ultralytics import YOLO
5
- import matplotlib.pyplot as plt
6
- import io
7
- from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
- model = YOLO('detect-best.pt')
9
-
10
- def predict(img, conf, iou):
11
- results = model.predict(img, conf=conf, iou=iou)
12
- name = results[0].names
13
- cls = results[0].boxes.cls
14
- crazing = 0
15
- inclusion = 0
16
- patches = 0
17
- pitted_surface = 0
18
- rolled_inscale = 0
19
- scratches = 0
20
- for i in cls:
21
- if i == 0:
22
- crazing += 1
23
- elif i == 1:
24
- inclusion += 1
25
- elif i == 2:
26
- patches += 1
27
- elif i == 3:
28
- pitted_surface += 1
29
- elif i == 4:
30
- rolled_inscale += 1
31
- elif i == 5:
32
- scratches += 1
33
- # 绘制柱状图
34
- fig, ax = plt.subplots()
35
- categories = ['crazing','inclusion', 'patches' ,'pitted_surface', 'rolled_inscale' ,'scratches']
36
- counts = [crazing,inclusion, patches ,pitted_surface, rolled_inscale ,scratches]
37
- ax.bar(categories, counts)
38
- ax.set_title('Category-Count')
39
- plt.ylim(0,5)
40
- plt.xticks(rotation=45, ha="right")
41
- ax.set_xlabel('Category')
42
- ax.set_ylabel('Count')
43
- # 将图表保存为字节流
44
- buf = io.BytesIO()
45
- canvas = FigureCanvas(fig)
46
- canvas.print_png(buf)
47
- plt.close(fig) # 关闭图形,释放资源
48
-
49
- # 将字节流转换为PIL Image
50
- image_png = Image.open(buf)
51
- # 绘制并返回结果图片和类别计数图表
52
-
53
- for i, r in enumerate(results):
54
- # Plot results image
55
- im_bgr = r.plot() # BGR-order numpy array
56
- im_rgb = Image.fromarray(im_bgr[..., ::-1]) # RGB-order PIL image
57
-
58
- # Show results to screen (in supported environments)
59
- return im_rgb, image_png
60
-
61
-
62
- base_conf, base_iou = 0.25, 0.45
63
- title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统"
64
- des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测"
65
- interface = gr.Interface(
66
- inputs=['image', gr.Slider(maximum=1, minimum=0, value=base_conf), gr.Slider(maximum=1, minimum=0, value=base_iou)],
67
- outputs=["image", 'image'], fn=predict, title=title, description=des,
68
- examples=[["example1.jpg", base_conf, base_iou],
69
- ["example2.jpg", base_conf, base_iou],
70
- ["example3.jpg", base_conf, base_iou]])
71
- interface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from ultralytics import YOLO
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
+ model = YOLO('detect-best.pt')
9
+
10
+ def predict(img, conf, iou):
11
+ results = model.predict(img, conf=conf, iou=iou)
12
+ name = results[0].names
13
+ cls = results[0].boxes.cls
14
+ crazing = 0
15
+ inclusion = 0
16
+ patches = 0
17
+ pitted_surface = 0
18
+ rolled_inscale = 0
19
+ scratches = 0
20
+ for i in cls:
21
+ if i == 0:
22
+ crazing += 1
23
+ elif i == 1:
24
+ inclusion += 1
25
+ elif i == 2:
26
+ patches += 1
27
+ elif i == 3:
28
+ pitted_surface += 1
29
+ elif i == 4:
30
+ rolled_inscale += 1
31
+ elif i == 5:
32
+ scratches += 1
33
+ # 绘制柱状图
34
+ fig, ax = plt.subplots()
35
+ categories = ['crazing','inclusion', 'patches' ,'pitted_surface', 'rolled_inscale' ,'scratches']
36
+ counts = [crazing,inclusion, patches ,pitted_surface, rolled_inscale ,scratches]
37
+ ax.bar(categories, counts)
38
+ ax.set_title('Category-Count')
39
+ plt.ylim(0,5)
40
+ plt.xticks(rotation=45, ha="right")
41
+ ax.set_xlabel('Category')
42
+ ax.set_ylabel('Count')
43
+ # 将图表保存为字节流
44
+ buf = io.BytesIO()
45
+ canvas = FigureCanvas(fig)
46
+ canvas.print_png(buf)
47
+ plt.close(fig) # 关闭图形,释放资源
48
+
49
+ # 将字节流转换为PIL Image
50
+ image_png = Image.open(buf)
51
+ # 绘制并返回结果图片和类别计数图表
52
+
53
+ for i, r in enumerate(results):
54
+ # Plot results image
55
+ im_bgr = r.plot() # BGR-order numpy array
56
+ im_rgb = Image.fromarray(im_bgr[..., ::-1]) # RGB-order PIL image
57
+
58
+ # Show results to screen (in supported environments)
59
+ return im_rgb
60
+
61
+ base_conf, base_iou = 0.25, 0.45
62
+ title = "基于改进YOLOv8算法的工业瑕疵辅助检测系统"
63
+ des = "鼠标点击上传图片即可检测缺陷,可通过鼠标调整预测置信度,还可点击网页最下方示例图片进行预测"
64
+ interface = gr.Interface(
65
+ inputs=['image', gr.Slider(maximum=1, minimum=0, value=base_conf), gr.Slider(maximum=1, minimum=0, value=base_iou)],
66
+ outputs=["image"], fn=predict, title=title, description=des,
67
+ examples=[["example1.jpg", base_conf, base_iou],
68
+ ["example2.jpg", base_conf, base_iou],
69
+ ["example3.jpg", base_conf, base_iou]])
70
+ interface.launch()