Spaces: Runtime error
jharrison27 committed
Commit • 67f0a55
1 Parent(s): fcb5f3a
Upload 13 files
- .gitattributes +1 -0
- app.py +21 -0
- class_names/bccd_classes.txt +3 -0
- class_names/coco_classes.txt +80 -0
- coco_classes.txt +80 -0
- config.py +17 -0
- custom_callbacks.py +15 -0
- custom_layers.py +298 -0
- loss.py +212 -0
- models.py +530 -0
- requirements.txt +96 -0
- utils.py +475 -0
- xml_to_txt.py +42 -0
- yolov4.weights +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+yolov4.weights filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,21 @@
import tensorflow as tf
import cv2
import numpy as np
from glob import glob
from models import Yolov4
import gradio as gr
model = Yolov4(weight_path="yolov4.weights", class_name_path='coco_classes.txt')
def gradio_wrapper(img):
    global model
    #print(np.shape(img))
    results = model.predict(img)
    return results[0]
demo = gr.Interface(
    gradio_wrapper,
    #gr.Image(source="webcam", streaming=True, flip=True),
    gr.Image(source="webcam", streaming=True),
    "image",
    live=True
)

demo.launch()
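Since the Space streams webcam frames, gradio_wrapper's contract is simply a NumPy image in, an annotated NumPy image out. A minimal smoke-test sketch, assuming the Yolov4.predict behavior defined in models.py below (the dummy frame is illustrative, not part of the repo):

# hedged local test of the wrapper contract; no Gradio server needed
import numpy as np
from models import Yolov4

model = Yolov4(weight_path="yolov4.weights", class_name_path="coco_classes.txt")
frame = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in for one webcam frame
annotated = model.predict(frame)[0]  # exactly what gradio_wrapper returns per frame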
class_names/bccd_classes.txt
ADDED
@@ -0,0 +1,3 @@
WBC
Platelets
RBC
class_names/coco_classes.txt
ADDED
@@ -0,0 +1,80 @@
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
coco_classes.txt
ADDED
@@ -0,0 +1,80 @@
(identical to class_names/coco_classes.txt above: the same 80 COCO class names, person through toothbrush)
config.py
ADDED
@@ -0,0 +1,17 @@
yolo_config = {
    # Basic
    'img_size': (416, 416, 3),
    'anchors': [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],
    'strides': [8, 16, 32],
    'xyscale': [1.2, 1.1, 1.05],

    # Training
    'iou_loss_thresh': 0.5,
    'batch_size': 8,
    'num_gpu': 1,  # 2,

    # Inference
    'max_boxes': 100,
    'iou_threshold': 0.413,
    'score_threshold': 0.3,
}
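As a cross-reference, models.py consumes the flat anchor list by reshaping it into three (width, height) pairs per detection scale; a short sketch of that relationship (values printed are taken straight from this config):

import numpy as np
from config import yolo_config

anchors = np.array(yolo_config['anchors']).reshape((3, 3, 2))
print(anchors[0])  # anchors for the stride-8 (small-object) scale: [[12 16] [19 36] [40 28]]
print([yolo_config['img_size'][0] // s for s in yolo_config['strides']])  # grid sizes [52, 26, 13]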
custom_callbacks.py
ADDED
@@ -0,0 +1,15 @@
from tensorflow.keras import callbacks
import math


class CosineAnnealingScheduler(callbacks.LearningRateScheduler):
    def __init__(self, epochs_per_cycle, lr_min, lr_max, verbose=0):
        super(callbacks.LearningRateScheduler, self).__init__()
        self.verbose = verbose
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.epochs_per_cycle = epochs_per_cycle

    def schedule(self, epoch, lr):
        return self.lr_min + (self.lr_max - self.lr_min) *\
               (1 + math.cos(math.pi * (epoch % self.epochs_per_cycle) / self.epochs_per_cycle)) / 2
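A usage sketch for the scheduler (the cycle length and learning-rate bounds here are illustrative, not values from this repo): the learning rate starts each cycle at lr_max, follows a half-cosine down toward lr_min, and restarts every epochs_per_cycle epochs.

from custom_callbacks import CosineAnnealingScheduler

# epoch 0 -> lr_max; mid-cycle -> midpoint; end of cycle -> near lr_min, then restart
scheduler = CosineAnnealingScheduler(epochs_per_cycle=10, lr_min=1e-5, lr_max=1e-3, verbose=1)
# e.g. pass it through Yolov4.fit(..., callbacks=[scheduler]) defined in models.py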
custom_layers.py
ADDED
@@ -0,0 +1,298 @@
import tensorflow as tf
from tensorflow.keras import layers, initializers, models


def conv(x, filters, kernel_size, downsampling=False, activation='leaky', batch_norm=True):
    def mish(x):
        return x * tf.math.tanh(tf.math.softplus(x))

    if downsampling:
        x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)  # top & left padding
        padding = 'valid'
        strides = 2
    else:
        padding = 'same'
        strides = 1
    x = layers.Conv2D(filters,
                      kernel_size,
                      strides=strides,
                      padding=padding,
                      use_bias=not batch_norm,
                      # kernel_regularizer=regularizers.l2(0.0005),
                      kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.01),
                      # bias_initializer=initializers.Zeros()
                      )(x)
    if batch_norm:
        x = layers.BatchNormalization()(x)
    if activation == 'mish':
        x = mish(x)
    elif activation == 'leaky':
        x = layers.LeakyReLU(alpha=0.1)(x)
    return x


def residual_block(x, filters1, filters2, activation='leaky'):
    """
    :param x: input tensor
    :param filters1: num of filters for the 1x1 conv
    :param filters2: num of filters for the 3x3 conv
    :param activation: default activation function: leaky relu
    :return:
    """
    y = conv(x, filters1, kernel_size=1, activation=activation)
    y = conv(y, filters2, kernel_size=3, activation=activation)
    return layers.Add()([x, y])


def csp_block(x, residual_out, repeat, residual_bottleneck=False):
    """
    Cross Stage Partial Network (CSPNet)
    transition_bottleneck_dims: 1x1 bottleneck
    output_dims: 3x3
    :param x:
    :param residual_out:
    :param repeat:
    :param residual_bottleneck:
    :return:
    """
    route = x
    route = conv(route, residual_out, 1, activation="mish")
    x = conv(x, residual_out, 1, activation="mish")
    for i in range(repeat):
        x = residual_block(x,
                           residual_out // 2 if residual_bottleneck else residual_out,
                           residual_out,
                           activation="mish")
    x = conv(x, residual_out, 1, activation="mish")

    x = layers.Concatenate()([x, route])
    return x


def darknet53(x):
    x = conv(x, 32, 3)
    x = conv(x, 64, 3, downsampling=True)

    for i in range(1):
        x = residual_block(x, 32, 64)
    x = conv(x, 128, 3, downsampling=True)

    for i in range(2):
        x = residual_block(x, 64, 128)
    x = conv(x, 256, 3, downsampling=True)

    for i in range(8):
        x = residual_block(x, 128, 256)
    route_1 = x
    x = conv(x, 512, 3, downsampling=True)

    for i in range(8):
        x = residual_block(x, 256, 512)
    route_2 = x
    x = conv(x, 1024, 3, downsampling=True)

    for i in range(4):
        x = residual_block(x, 512, 1024)

    return route_1, route_2, x


def cspdarknet53(input):
    x = conv(input, 32, 3)
    x = conv(x, 64, 3, downsampling=True)

    x = csp_block(x, residual_out=64, repeat=1, residual_bottleneck=True)
    x = conv(x, 64, 1, activation='mish')
    x = conv(x, 128, 3, activation='mish', downsampling=True)

    x = csp_block(x, residual_out=64, repeat=2)
    x = conv(x, 128, 1, activation='mish')
    x = conv(x, 256, 3, activation='mish', downsampling=True)

    x = csp_block(x, residual_out=128, repeat=8)
    x = conv(x, 256, 1, activation='mish')
    route0 = x
    x = conv(x, 512, 3, activation='mish', downsampling=True)

    x = csp_block(x, residual_out=256, repeat=8)
    x = conv(x, 512, 1, activation='mish')
    route1 = x
    x = conv(x, 1024, 3, activation='mish', downsampling=True)

    x = csp_block(x, residual_out=512, repeat=4)

    x = conv(x, 1024, 1, activation="mish")

    x = conv(x, 512, 1)
    x = conv(x, 1024, 3)
    x = conv(x, 512, 1)

    # SPP: max pooling at several scales, concatenated with the input feature map
    x = layers.Concatenate()([layers.MaxPooling2D(pool_size=13, strides=1, padding='same')(x),
                              layers.MaxPooling2D(pool_size=9, strides=1, padding='same')(x),
                              layers.MaxPooling2D(pool_size=5, strides=1, padding='same')(x),
                              x
                              ])
    x = conv(x, 512, 1)
    x = conv(x, 1024, 3)
    route2 = conv(x, 512, 1)
    return models.Model(input, [route0, route1, route2])


def yolov4_neck(x, num_classes):
    backbone_model = cspdarknet53(x)
    route0, route1, route2 = backbone_model.output

    route_input = route2
    x = conv(route2, 256, 1)
    x = layers.UpSampling2D()(x)
    route1 = conv(route1, 256, 1)
    x = layers.Concatenate()([route1, x])

    x = conv(x, 256, 1)
    x = conv(x, 512, 3)
    x = conv(x, 256, 1)
    x = conv(x, 512, 3)
    x = conv(x, 256, 1)

    route1 = x
    x = conv(x, 128, 1)
    x = layers.UpSampling2D()(x)
    route0 = conv(route0, 128, 1)
    x = layers.Concatenate()([route0, x])

    x = conv(x, 128, 1)
    x = conv(x, 256, 3)
    x = conv(x, 128, 1)
    x = conv(x, 256, 3)
    x = conv(x, 128, 1)

    route0 = x
    x = conv(x, 256, 3)
    conv_sbbox = conv(x, 3 * (num_classes + 5), 1, activation=None, batch_norm=False)

    x = conv(route0, 256, 3, downsampling=True)
    x = layers.Concatenate()([x, route1])

    x = conv(x, 256, 1)
    x = conv(x, 512, 3)
    x = conv(x, 256, 1)
    x = conv(x, 512, 3)
    x = conv(x, 256, 1)

    route1 = x
    x = conv(x, 512, 3)
    conv_mbbox = conv(x, 3 * (num_classes + 5), 1, activation=None, batch_norm=False)

    x = conv(route1, 512, 3, downsampling=True)
    x = layers.Concatenate()([x, route_input])

    x = conv(x, 512, 1)
    x = conv(x, 1024, 3)
    x = conv(x, 512, 1)
    x = conv(x, 1024, 3)
    x = conv(x, 512, 1)

    x = conv(x, 1024, 3)
    conv_lbbox = conv(x, 3 * (num_classes + 5), 1, activation=None, batch_norm=False)

    return [conv_sbbox, conv_mbbox, conv_lbbox]


def yolov4_head(yolo_neck_outputs, classes, anchors, xyscale):
    bbox0, object_probability0, class_probabilities0, pred_box0 = get_boxes(yolo_neck_outputs[0],
                                                                            anchors=anchors[0, :, :], classes=classes,
                                                                            grid_size=52, strides=8,
                                                                            xyscale=xyscale[0])
    bbox1, object_probability1, class_probabilities1, pred_box1 = get_boxes(yolo_neck_outputs[1],
                                                                            anchors=anchors[1, :, :], classes=classes,
                                                                            grid_size=26, strides=16,
                                                                            xyscale=xyscale[1])
    bbox2, object_probability2, class_probabilities2, pred_box2 = get_boxes(yolo_neck_outputs[2],
                                                                            anchors=anchors[2, :, :], classes=classes,
                                                                            grid_size=13, strides=32,
                                                                            xyscale=xyscale[2])
    x = [bbox0, object_probability0, class_probabilities0, pred_box0,
         bbox1, object_probability1, class_probabilities1, pred_box1,
         bbox2, object_probability2, class_probabilities2, pred_box2]

    return x


def get_boxes(pred, anchors, classes, grid_size, strides, xyscale):
    """
    :param pred:
    :param anchors:
    :param classes:
    :param grid_size:
    :param strides:
    :param xyscale:
    :return:
    """
    pred = tf.reshape(pred,
                      (tf.shape(pred)[0],
                       grid_size,
                       grid_size,
                       3,
                       5 + classes))  # (batch_size, grid_size, grid_size, 3, 5+classes)
    box_xy, box_wh, obj_prob, class_prob = tf.split(
        pred, (2, 2, 1, classes), axis=-1
    )  # (?, 52, 52, 3, 2) (?, 52, 52, 3, 2) (?, 52, 52, 3, 1) (?, 52, 52, 3, 80)

    box_xy = tf.sigmoid(box_xy)  # (?, 52, 52, 3, 2)
    obj_prob = tf.sigmoid(obj_prob)  # (?, 52, 52, 3, 1)
    class_prob = tf.sigmoid(class_prob)  # (?, 52, 52, 3, 80)
    pred_box_xywh = tf.concat((box_xy, box_wh), axis=-1)  # (?, 52, 52, 3, 4)

    grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))  # (52, 52) (52, 52)
    grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)  # (52, 52, 1, 2)
    grid = tf.cast(grid, dtype=tf.float32)

    box_xy = ((box_xy * xyscale) - 0.5 * (xyscale - 1) + grid) * strides  # (?, 52, 52, 1, 4)

    box_wh = tf.exp(box_wh) * anchors  # (?, 52, 52, 3, 2)
    box_x1y1 = box_xy - box_wh / 2  # (?, 52, 52, 3, 2)
    box_x2y2 = box_xy + box_wh / 2  # (?, 52, 52, 3, 2)
    pred_box_x1y1x2y2 = tf.concat([box_x1y1, box_x2y2], axis=-1)  # (?, 52, 52, 3, 4)
    return pred_box_x1y1x2y2, obj_prob, class_prob, pred_box_xywh
    # pred_box_x1y1x2y2: absolute xy value


def nms(model_ouputs, input_shape, num_class, iou_threshold=0.413, score_threshold=0.3):
    """
    Apply Non-Maximum Suppression
    ref: https://www.tensorflow.org/api_docs/python/tf/image/combined_non_max_suppression
    :param model_ouputs: yolo model model_ouputs
    :param input_shape: size of input image
    :return: nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
    """
    bs = tf.shape(model_ouputs[0])[0]
    boxes = tf.zeros((bs, 0, 4))
    confidence = tf.zeros((bs, 0, 1))
    class_probabilities = tf.zeros((bs, 0, num_class))

    for output_idx in range(0, len(model_ouputs), 4):
        output_xy = model_ouputs[output_idx]
        output_conf = model_ouputs[output_idx + 1]
        output_classes = model_ouputs[output_idx + 2]
        boxes = tf.concat([boxes, tf.reshape(output_xy, (bs, -1, 4))], axis=1)
        confidence = tf.concat([confidence, tf.reshape(output_conf, (bs, -1, 1))], axis=1)
        class_probabilities = tf.concat([class_probabilities, tf.reshape(output_classes, (bs, -1, num_class))], axis=1)

    scores = confidence * class_probabilities
    boxes = tf.expand_dims(boxes, axis=-2)
    boxes = boxes / input_shape[0]  # box normalization: relative to img size
    print(f'nms iou: {iou_threshold} score: {score_threshold}')
    (nmsed_boxes,      # [bs, max_detections, 4]
     nmsed_scores,     # [bs, max_detections]
     nmsed_classes,    # [bs, max_detections]
     valid_detections  # [batch_size]
     ) = tf.image.combined_non_max_suppression(
        boxes=boxes,  # y1x1, y2x2 [0~1]
        scores=scores,
        max_output_size_per_class=100,
        max_total_size=100,  # max_boxes: maximum nmsed_boxes in a single img
        iou_threshold=iou_threshold,  # minimum overlap that counts as a valid detection
        score_threshold=score_threshold,  # minimum confidence that counts as a valid detection
    )
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
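A quick shape check on the neck wiring (a sketch assuming the 416x416 input from config.py): the three raw outputs should sit at strides 8, 16, and 32 with 3 * (num_classes + 5) channels each.

from tensorflow.keras import layers
from custom_layers import yolov4_neck

inputs = layers.Input((416, 416, 3))
sbbox, mbbox, lbbox = yolov4_neck(inputs, num_classes=80)
print(sbbox.shape, mbbox.shape, lbbox.shape)
# expected: (None, 52, 52, 255), (None, 26, 26, 255), (None, 13, 13, 255)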
loss.py
ADDED
@@ -0,0 +1,212 @@
#!/usr/bin/env python
# coding: utf-8

import numpy as np
import math
import tensorflow.keras.backend as K
import tensorflow as tf


def xywh_to_x1y1x2y2(boxes):
    return tf.concat([boxes[..., :2] - boxes[..., 2:] * 0.5, boxes[..., :2] + boxes[..., 2:] * 0.5], axis=-1)


# x,y,w,h
def bbox_iou(boxes1, boxes2):
    boxes1_area = boxes1[..., 2] * boxes1[..., 3]  # w * h
    boxes2_area = boxes2[..., 2] * boxes2[..., 3]

    # (x, y, w, h) -> (x0, y0, x1, y1)
    boxes1 = xywh_to_x1y1x2y2(boxes1)
    boxes2 = xywh_to_x1y1x2y2(boxes2)

    # coordinates of intersection
    top_left = tf.maximum(boxes1[..., :2], boxes2[..., :2])
    bottom_right = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])
    intersection_xy = tf.maximum(bottom_right - top_left, 0.0)

    intersection_area = intersection_xy[..., 0] * intersection_xy[..., 1]
    union_area = boxes1_area + boxes2_area - intersection_area

    return 1.0 * intersection_area / (union_area + tf.keras.backend.epsilon())


def bbox_giou(boxes1, boxes2):
    boxes1_area = boxes1[..., 2] * boxes1[..., 3]  # w*h
    boxes2_area = boxes2[..., 2] * boxes2[..., 3]

    # (x, y, w, h) -> (x0, y0, x1, y1)
    boxes1 = xywh_to_x1y1x2y2(boxes1)
    boxes2 = xywh_to_x1y1x2y2(boxes2)

    top_left = tf.maximum(boxes1[..., :2], boxes2[..., :2])
    bottom_right = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])

    intersection_xy = tf.maximum(bottom_right - top_left, 0.0)
    intersection_area = intersection_xy[..., 0] * intersection_xy[..., 1]

    union_area = boxes1_area + boxes2_area - intersection_area

    iou = 1.0 * intersection_area / (union_area + tf.keras.backend.epsilon())

    enclose_top_left = tf.minimum(boxes1[..., :2], boxes2[..., :2])
    enclose_bottom_right = tf.maximum(boxes1[..., 2:], boxes2[..., 2:])

    enclose_xy = enclose_bottom_right - enclose_top_left
    enclose_area = enclose_xy[..., 0] * enclose_xy[..., 1]

    giou = iou - tf.math.divide_no_nan(enclose_area - union_area, enclose_area)

    return giou


def bbox_ciou(boxes1, boxes2):
    '''
    ciou = iou - p2/c2 - av
    :param boxes1: (8, 13, 13, 3, 4) pred_xywh
    :param boxes2: (8, 13, 13, 3, 4) label_xywh
    :return:
    '''
    boxes1_x0y0x1y1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                                 boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
    boxes2_x0y0x1y1 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                                 boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)
    boxes1_x0y0x1y1 = tf.concat([tf.minimum(boxes1_x0y0x1y1[..., :2], boxes1_x0y0x1y1[..., 2:]),
                                 tf.maximum(boxes1_x0y0x1y1[..., :2], boxes1_x0y0x1y1[..., 2:])], axis=-1)
    boxes2_x0y0x1y1 = tf.concat([tf.minimum(boxes2_x0y0x1y1[..., :2], boxes2_x0y0x1y1[..., 2:]),
                                 tf.maximum(boxes2_x0y0x1y1[..., :2], boxes2_x0y0x1y1[..., 2:])], axis=-1)

    # area
    boxes1_area = (boxes1_x0y0x1y1[..., 2] - boxes1_x0y0x1y1[..., 0]) * (
            boxes1_x0y0x1y1[..., 3] - boxes1_x0y0x1y1[..., 1])
    boxes2_area = (boxes2_x0y0x1y1[..., 2] - boxes2_x0y0x1y1[..., 0]) * (
            boxes2_x0y0x1y1[..., 3] - boxes2_x0y0x1y1[..., 1])

    # top-left and bottom-right coord, shape: (8, 13, 13, 3, 2)
    left_up = tf.maximum(boxes1_x0y0x1y1[..., :2], boxes2_x0y0x1y1[..., :2])
    right_down = tf.minimum(boxes1_x0y0x1y1[..., 2:], boxes2_x0y0x1y1[..., 2:])

    # intersection area and iou
    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = boxes1_area + boxes2_area - inter_area
    iou = inter_area / (union_area + 1e-9)

    # top-left and bottom-right coord of the enclosing rectangle, shape: (8, 13, 13, 3, 2)
    enclose_left_up = tf.minimum(boxes1_x0y0x1y1[..., :2], boxes2_x0y0x1y1[..., :2])
    enclose_right_down = tf.maximum(boxes1_x0y0x1y1[..., 2:], boxes2_x0y0x1y1[..., 2:])

    # diagonal ** 2
    enclose_wh = enclose_right_down - enclose_left_up
    enclose_c2 = K.pow(enclose_wh[..., 0], 2) + K.pow(enclose_wh[..., 1], 2)

    # center distance between the two rectangles
    p2 = K.pow(boxes1[..., 0] - boxes2[..., 0], 2) + K.pow(boxes1[..., 1] - boxes2[..., 1], 2)

    # add av
    atan1 = tf.atan(boxes1[..., 2] / (boxes1[..., 3] + 1e-9))
    atan2 = tf.atan(boxes2[..., 2] / (boxes2[..., 3] + 1e-9))
    v = 4.0 * K.pow(atan1 - atan2, 2) / (math.pi ** 2)
    a = v / (1 - iou + v)

    ciou = iou - 1.0 * p2 / enclose_c2 - 1.0 * a * v
    return ciou


def yolo_loss(args, num_classes, iou_loss_thresh, anchors):
    conv_lbbox = args[2]  # (?, ?, ?, 3*(num_classes+5))
    conv_mbbox = args[1]  # (?, ?, ?, 3*(num_classes+5))
    conv_sbbox = args[0]  # (?, ?, ?, 3*(num_classes+5))
    label_sbbox = args[3]  # (?, ?, ?, 3, num_classes+5)
    label_mbbox = args[4]  # (?, ?, ?, 3, num_classes+5)
    label_lbbox = args[5]  # (?, ?, ?, 3, num_classes+5)
    true_bboxes = args[6]  # (?, 50, 4)
    pred_sbbox = decode(conv_sbbox, anchors[0], 8, num_classes)
    pred_mbbox = decode(conv_mbbox, anchors[1], 16, num_classes)
    pred_lbbox = decode(conv_lbbox, anchors[2], 32, num_classes)
    sbbox_ciou_loss, sbbox_conf_loss, sbbox_prob_loss = loss_layer(conv_sbbox, pred_sbbox, label_sbbox, true_bboxes, 8, num_classes, iou_loss_thresh)
    mbbox_ciou_loss, mbbox_conf_loss, mbbox_prob_loss = loss_layer(conv_mbbox, pred_mbbox, label_mbbox, true_bboxes, 16, num_classes, iou_loss_thresh)
    lbbox_ciou_loss, lbbox_conf_loss, lbbox_prob_loss = loss_layer(conv_lbbox, pred_lbbox, label_lbbox, true_bboxes, 32, num_classes, iou_loss_thresh)

    ciou_loss = (lbbox_ciou_loss + sbbox_ciou_loss + mbbox_ciou_loss) * 3.54
    conf_loss = (lbbox_conf_loss + sbbox_conf_loss + mbbox_conf_loss) * 64.3
    prob_loss = (lbbox_prob_loss + sbbox_prob_loss + mbbox_prob_loss) * 1

    return ciou_loss + conf_loss + prob_loss


def loss_layer(conv, pred, label, bboxes, stride, num_class, iou_loss_thresh):
    conv_shape = tf.shape(conv)
    batch_size = conv_shape[0]
    output_size = conv_shape[1]
    input_size = stride * output_size
    conv = tf.reshape(conv, (batch_size, output_size, output_size,
                             3, 5 + num_class))
    conv_raw_prob = conv[:, :, :, :, 5:]
    conv_raw_conf = conv[:, :, :, :, 4:5]

    pred_xywh = pred[:, :, :, :, 0:4]
    pred_conf = pred[:, :, :, :, 4:5]

    label_xywh = label[:, :, :, :, 0:4]
    respond_bbox = label[:, :, :, :, 4:5]
    label_prob = label[:, :, :, :, 5:]

    # Coordinate loss
    ciou = tf.expand_dims(bbox_giou(pred_xywh, label_xywh), axis=-1)  # (8, 13, 13, 3, 1)
    # ciou = tf.expand_dims(bbox_ciou(pred_xywh, label_xywh), axis=-1)  # (8, 13, 13, 3, 1)
    input_size = tf.cast(input_size, tf.float32)

    # loss weight of the gt bbox: 2-(gt area/img area)
    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
    ciou_loss = respond_bbox * bbox_loss_scale * (1 - ciou)  # iou loss for respond bbox

    # Classification loss for respond bbox
    prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_prob, logits=conv_raw_prob)

    expand_pred_xywh = pred_xywh[:, :, :, :, np.newaxis, :]  # (?, grid_h, grid_w, 3, 1, 4)
    expand_bboxes = bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :]  # (?, 1, 1, 1, 70, 4)
    iou = bbox_iou(expand_pred_xywh, expand_bboxes)  # IoU between all pred bboxes and all gt (?, grid_h, grid_w, 3, 70)
    max_iou = tf.expand_dims(tf.reduce_max(iou, axis=-1), axis=-1)  # max iou: (?, grid_h, grid_w, 3, 1)

    # ignore a bbox which is not a respond bbox and whose max iou < threshold
    respond_bgd = (1.0 - respond_bbox) * tf.cast(max_iou < iou_loss_thresh, tf.float32)

    # Confidence loss
    conf_focal = tf.pow(respond_bbox - pred_conf, 2)

    conf_loss = conf_focal * (
            respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
            +
            respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
    )

    ciou_loss = tf.reduce_mean(tf.reduce_sum(ciou_loss, axis=[1, 2, 3, 4]))
    conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis=[1, 2, 3, 4]))
    prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis=[1, 2, 3, 4]))

    return ciou_loss, conf_loss, prob_loss


def decode(conv_output, anchors, stride, num_class):
    conv_shape = tf.shape(conv_output)
    batch_size = conv_shape[0]
    output_size = conv_shape[1]
    anchor_per_scale = len(anchors)
    conv_output = tf.reshape(conv_output, (batch_size, output_size, output_size, anchor_per_scale, 5 + num_class))
    conv_raw_dxdy = conv_output[:, :, :, :, 0:2]
    conv_raw_dwdh = conv_output[:, :, :, :, 2:4]
    conv_raw_conf = conv_output[:, :, :, :, 4:5]
    conv_raw_prob = conv_output[:, :, :, :, 5:]
    y = tf.tile(tf.range(output_size, dtype=tf.int32)[:, tf.newaxis], [1, output_size])
    x = tf.tile(tf.range(output_size, dtype=tf.int32)[tf.newaxis, :], [output_size, 1])
    xy_grid = tf.concat([x[:, :, tf.newaxis], y[:, :, tf.newaxis]], axis=-1)
    xy_grid = tf.tile(xy_grid[tf.newaxis, :, :, tf.newaxis, :], [batch_size, 1, 1, anchor_per_scale, 1])
    xy_grid = tf.cast(xy_grid, tf.float32)
    pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * stride
    pred_wh = (tf.exp(conv_raw_dwdh) * anchors)
    pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
    pred_conf = tf.sigmoid(conv_raw_conf)
    pred_prob = tf.sigmoid(conv_raw_prob)
    return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)
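A small numeric check of bbox_iou on (x, y, w, h) boxes, with made-up coordinates: two unit squares whose centers are offset horizontally by 0.5 overlap in a 0.5 x 1 strip, so IoU = 0.5 / 1.5 ≈ 0.333.

import tensorflow as tf
from loss import bbox_iou

b1 = tf.constant([[0.5, 0.5, 1.0, 1.0]])  # center (0.5, 0.5), 1x1 box
b2 = tf.constant([[1.0, 0.5, 1.0, 1.0]])  # same box shifted right by 0.5
print(bbox_iou(b1, b2).numpy())  # ~0.3333: intersection 0.5, union 1.5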
models.py
ADDED
@@ -0,0 +1,530 @@
import numpy as np
import cv2
import os
import json
from tqdm import tqdm
from glob import glob
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers

from custom_layers import yolov4_neck, yolov4_head, nms
from utils import load_weights, get_detection_data, draw_bbox, voc_ap, draw_plot_func, read_txt_to_list
from config import yolo_config
from loss import yolo_loss


class Yolov4(object):
    def __init__(self,
                 weight_path=None,
                 class_name_path='coco_classes.txt',
                 config=yolo_config,
                 ):
        assert config['img_size'][0] == config['img_size'][1], 'not support yet'
        assert config['img_size'][0] % config['strides'][-1] == 0, 'must be a multiple of last stride'
        self.class_names = [line.strip() for line in open(class_name_path).readlines()]
        self.img_size = yolo_config['img_size']
        self.num_classes = len(self.class_names)
        self.weight_path = weight_path
        self.anchors = np.array(yolo_config['anchors']).reshape((3, 3, 2))
        self.xyscale = yolo_config['xyscale']
        self.strides = yolo_config['strides']
        self.output_sizes = [self.img_size[0] // s for s in self.strides]
        self.class_color = {name: list(np.random.random(size=3) * 255) for name in self.class_names}
        # Training
        self.max_boxes = yolo_config['max_boxes']
        self.iou_loss_thresh = yolo_config['iou_loss_thresh']
        self.config = yolo_config
        assert self.num_classes > 0, 'no classes detected!'

        tf.keras.backend.clear_session()
        if yolo_config['num_gpu'] > 1:
            mirrored_strategy = tf.distribute.MirroredStrategy()
            with mirrored_strategy.scope():
                self.build_model(load_pretrained=True if self.weight_path else False)
        else:
            self.build_model(load_pretrained=True if self.weight_path else False)

    def build_model(self, load_pretrained=True):
        # core yolo model
        input_layer = layers.Input(self.img_size)
        yolov4_output = yolov4_neck(input_layer, self.num_classes)
        self.yolo_model = models.Model(input_layer, yolov4_output)

        # Build training model
        y_true = [
            layers.Input(name='input_2', shape=(52, 52, 3, (self.num_classes + 5))),  # label small boxes
            layers.Input(name='input_3', shape=(26, 26, 3, (self.num_classes + 5))),  # label medium boxes
            layers.Input(name='input_4', shape=(13, 13, 3, (self.num_classes + 5))),  # label large boxes
            layers.Input(name='input_5', shape=(self.max_boxes, 4)),  # true bboxes
        ]
        loss_list = tf.keras.layers.Lambda(yolo_loss, name='yolo_loss',
                                           arguments={'num_classes': self.num_classes,
                                                      'iou_loss_thresh': self.iou_loss_thresh,
                                                      'anchors': self.anchors})([*self.yolo_model.output, *y_true])
        self.training_model = models.Model([self.yolo_model.input, *y_true], loss_list)

        # Build inference model
        yolov4_output = yolov4_head(yolov4_output, self.num_classes, self.anchors, self.xyscale)
        # output: [boxes, scores, classes, valid_detections]
        self.inference_model = models.Model(input_layer,
                                            nms(yolov4_output, self.img_size, self.num_classes,
                                                iou_threshold=self.config['iou_threshold'],
                                                score_threshold=self.config['score_threshold']))

        if load_pretrained and self.weight_path:
            if self.weight_path.endswith('.weights'):
                load_weights(self.yolo_model, self.weight_path)
                print(f'load from {self.weight_path}')
            elif self.weight_path.endswith('.h5'):
                self.training_model.load_weights(self.weight_path)
                print(f'load from {self.weight_path}')

        self.training_model.compile(optimizer=optimizers.Adam(lr=1e-3),
                                    loss={'yolo_loss': lambda y_true, y_pred: y_pred})

    def load_model(self, path):
        self.yolo_model = models.load_model(path, compile=False)
        yolov4_output = yolov4_head(self.yolo_model.output, self.num_classes, self.anchors, self.xyscale)
        self.inference_model = models.Model(self.yolo_model.input,
                                            nms(yolov4_output, self.img_size, self.num_classes))  # [boxes, scores, classes, valid_detections]

    def save_model(self, path):
        self.yolo_model.save(path)

    def preprocess_img(self, img):
        img = cv2.resize(img, self.img_size[:2])
        img = img / 255.
        return img

    def fit(self, train_data_gen, epochs, val_data_gen=None, initial_epoch=0, callbacks=None):
        self.training_model.fit(train_data_gen,
                                steps_per_epoch=len(train_data_gen),
                                validation_data=val_data_gen,
                                validation_steps=len(val_data_gen),
                                epochs=epochs,
                                callbacks=callbacks,
                                initial_epoch=initial_epoch)

    # raw_img: RGB
    def predict_img(self, raw_img, random_color=True, plot_img=True, figsize=(10, 10), show_text=True, return_output=True):
        print('img shape: ', raw_img.shape)
        img = self.preprocess_img(raw_img)
        imgs = np.expand_dims(img, axis=0)
        pred_output = self.inference_model.predict(imgs)
        detections = get_detection_data(img=raw_img,
                                        model_outputs=pred_output,
                                        class_names=self.class_names)

        output_img = draw_bbox(raw_img, detections, cmap=self.class_color, random_color=random_color, figsize=figsize,
                               show_text=show_text, show_img=False)
        if return_output:
            return output_img, detections
        else:
            return detections

    def predict(self, img_path, random_color=True, plot_img=True, figsize=(10, 10), show_text=True):
        raw_img = img_path
        return self.predict_img(raw_img, random_color, plot_img, figsize, show_text)

    def export_gt(self, annotation_path, gt_folder_path):
        with open(annotation_path) as file:
            for line in file:
                line = line.split(' ')
                filename = line[0].split(os.sep)[-1].split('.')[0]
                objs = line[1:]
                # export txt file
                with open(os.path.join(gt_folder_path, filename + '.txt'), 'w') as output_file:
                    for obj in objs:
                        x_min, y_min, x_max, y_max, class_id = [float(o) for o in obj.strip().split(',')]
                        output_file.write(f'{self.class_names[int(class_id)]} {x_min} {y_min} {x_max} {y_max}\n')

    def export_prediction(self, annotation_path, pred_folder_path, img_folder_path, bs=2):
        with open(annotation_path) as file:
            img_paths = [os.path.join(img_folder_path, line.split(' ')[0].split(os.sep)[-1]) for line in file]
            # print(img_paths[:20])
            for batch_idx in tqdm(range(0, len(img_paths), bs)):
                # print(len(img_paths), batch_idx, batch_idx*bs, (batch_idx+1)*bs)
                paths = img_paths[batch_idx:batch_idx + bs]
                # print(paths)
                # read and process img
                imgs = np.zeros((len(paths), *self.img_size))
                raw_img_shapes = []
                for j, path in enumerate(paths):
                    img = cv2.imread(path)
                    raw_img_shapes.append(img.shape)
                    img = self.preprocess_img(img)
                    imgs[j] = img

                # process batch output
                b_boxes, b_scores, b_classes, b_valid_detections = self.inference_model.predict(imgs)
                for k in range(len(paths)):
                    num_boxes = b_valid_detections[k]
                    raw_img_shape = raw_img_shapes[k]
                    boxes = b_boxes[k, :num_boxes]
                    classes = b_classes[k, :num_boxes]
                    scores = b_scores[k, :num_boxes]
                    # print(raw_img_shape)
                    boxes[:, [0, 2]] = (boxes[:, [0, 2]] * raw_img_shape[1])  # w
                    boxes[:, [1, 3]] = (boxes[:, [1, 3]] * raw_img_shape[0])  # h
                    cls_names = [self.class_names[int(c)] for c in classes]
                    # print(raw_img_shape, boxes.astype(int), cls_names, scores)

                    img_path = paths[k]
                    filename = img_path.split(os.sep)[-1].split('.')[0]
                    # print(filename)
                    output_path = os.path.join(pred_folder_path, filename + '.txt')
                    with open(output_path, 'w') as pred_file:
                        for box_idx in range(num_boxes):
                            b = boxes[box_idx]
                            pred_file.write(f'{cls_names[box_idx]} {scores[box_idx]} {b[0]} {b[1]} {b[2]} {b[3]}\n')

    def eval_map(self, gt_folder_path, pred_folder_path, temp_json_folder_path, output_files_path):
        """Process ground truth"""
        ground_truth_files_list = glob(gt_folder_path + '/*.txt')
        assert len(ground_truth_files_list) > 0, 'no ground truth file'
        ground_truth_files_list.sort()
        # dictionary with counter per class
        gt_counter_per_class = {}
        counter_images_per_class = {}

        gt_files = []
        for txt_file in ground_truth_files_list:
            file_id = txt_file.split(".txt", 1)[0]
            file_id = os.path.basename(os.path.normpath(file_id))
            # check if there is a correspondent detection-results file
            temp_path = os.path.join(pred_folder_path, (file_id + ".txt"))
            assert os.path.exists(temp_path), "Error. File not found: {}\n".format(temp_path)
            lines_list = read_txt_to_list(txt_file)
            # create ground-truth dictionary
            bounding_boxes = []
            is_difficult = False
            already_seen_classes = []
            for line in lines_list:
                class_name, left, top, right, bottom = line.split()
                # check if class is in the ignore list, if yes skip
                bbox = left + " " + top + " " + right + " " + bottom
                bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False})
                # count that object
                if class_name in gt_counter_per_class:
                    gt_counter_per_class[class_name] += 1
                else:
                    # if class didn't exist yet
                    gt_counter_per_class[class_name] = 1

                if class_name not in already_seen_classes:
                    if class_name in counter_images_per_class:
                        counter_images_per_class[class_name] += 1
                    else:
                        # if class didn't exist yet
                        counter_images_per_class[class_name] = 1
                    already_seen_classes.append(class_name)

            # dump bounding_boxes into a ".json" file
            new_temp_file = os.path.join(temp_json_folder_path, file_id + "_ground_truth.json")  # TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
            gt_files.append(new_temp_file)
            with open(new_temp_file, 'w') as outfile:
                json.dump(bounding_boxes, outfile)

        gt_classes = list(gt_counter_per_class.keys())
        # let's sort the classes alphabetically
        gt_classes = sorted(gt_classes)
        n_classes = len(gt_classes)
        print(gt_classes, gt_counter_per_class)

        """Process prediction"""

        dr_files_list = sorted(glob(os.path.join(pred_folder_path, '*.txt')))

        for class_index, class_name in enumerate(gt_classes):
            bounding_boxes = []
            for txt_file in dr_files_list:
                # the first time it checks if all the corresponding ground-truth files exist
                file_id = txt_file.split(".txt", 1)[0]
                file_id = os.path.basename(os.path.normpath(file_id))
                temp_path = os.path.join(gt_folder_path, (file_id + ".txt"))
                if class_index == 0:
                    if not os.path.exists(temp_path):
                        error_msg = f"Error. File not found: {temp_path}\n"
                        print(error_msg)
                lines = read_txt_to_list(txt_file)
                for line in lines:
                    try:
                        tmp_class_name, confidence, left, top, right, bottom = line.split()
                    except ValueError:
                        error_msg = f"""Error: File {txt_file} in the wrong format.\n
                                        Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n
                                        Received: {line} \n"""
                        print(error_msg)
                    if tmp_class_name == class_name:
                        # print("match")
                        bbox = left + " " + top + " " + right + " " + bottom
                        bounding_boxes.append({"confidence": confidence, "file_id": file_id, "bbox": bbox})
            # sort detection-results by decreasing confidence
            bounding_boxes.sort(key=lambda x: float(x['confidence']), reverse=True)
            with open(temp_json_folder_path + "/" + class_name + "_dr.json", 'w') as outfile:
                json.dump(bounding_boxes, outfile)

        """
        Calculate the AP for each class
        """
        sum_AP = 0.0
        ap_dictionary = {}
        # open file to store the output
        with open(output_files_path + "/output.txt", 'w') as output_file:
            output_file.write("# AP and precision/recall per class\n")
            count_true_positives = {}
            for class_index, class_name in enumerate(gt_classes):
                count_true_positives[class_name] = 0
                """
                Load detection-results of that class
                """
                dr_file = temp_json_folder_path + "/" + class_name + "_dr.json"
                dr_data = json.load(open(dr_file))

                """
                Assign detection-results to ground-truth objects
                """
                nd = len(dr_data)
                tp = [0] * nd  # creates an array of zeros of size nd
                fp = [0] * nd
                for idx, detection in enumerate(dr_data):
                    file_id = detection["file_id"]
                    gt_file = temp_json_folder_path + "/" + file_id + "_ground_truth.json"
                    ground_truth_data = json.load(open(gt_file))
                    ovmax = -1
                    gt_match = -1
                    # load detected object bounding-box
                    bb = [float(x) for x in detection["bbox"].split()]
                    for obj in ground_truth_data:
                        # look for a class_name match
                        if obj["class_name"] == class_name:
                            bbgt = [float(x) for x in obj["bbox"].split()]
                            bi = [max(bb[0], bbgt[0]), max(bb[1], bbgt[1]), min(bb[2], bbgt[2]), min(bb[3], bbgt[3])]
                            iw = bi[2] - bi[0] + 1
                            ih = bi[3] - bi[1] + 1
                            if iw > 0 and ih > 0:
                                # compute overlap (IoU) = area of intersection / area of union
                                ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + \
                                     (bbgt[2] - bbgt[0] + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                                ov = iw * ih / ua
                                if ov > ovmax:
                                    ovmax = ov
                                    gt_match = obj

                    min_overlap = 0.5
                    if ovmax >= min_overlap:
                        # if "difficult" not in gt_match:
                        if not bool(gt_match["used"]):
                            # true positive
                            tp[idx] = 1
                            gt_match["used"] = True
                            count_true_positives[class_name] += 1
                            # update the ".json" file
                            with open(gt_file, 'w') as f:
                                f.write(json.dumps(ground_truth_data))
                        else:
                            # false positive (multiple detection)
                            fp[idx] = 1
                    else:
                        fp[idx] = 1

                # compute precision/recall
                cumsum = 0
                for idx, val in enumerate(fp):
                    fp[idx] += cumsum
                    cumsum += val
                print('fp ', cumsum)
                cumsum = 0
                for idx, val in enumerate(tp):
                    tp[idx] += cumsum
                    cumsum += val
                print('tp ', cumsum)
                rec = tp[:]
                for idx, val in enumerate(tp):
                    rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
                print('recall ', cumsum)
                prec = tp[:]
                for idx, val in enumerate(tp):
                    prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
                print('prec ', cumsum)

                ap, mrec, mprec = voc_ap(rec[:], prec[:])
                sum_AP += ap
                text = "{0:.2f}%".format(ap * 100) + " = " + class_name + " AP "  # class_name + " AP = {0:.2f}%".format(ap*100)

                print(text)
                ap_dictionary[class_name] = ap

                n_images = counter_images_per_class[class_name]
                # lamr, mr, fppi = log_average_miss_rate(np.array(prec), np.array(rec), n_images)
                # lamr_dictionary[class_name] = lamr

                """
                Draw plot
                """
                if True:
                    plt.plot(rec, prec, '-o')
                    # add a new penultimate point to the list (mrec[-2], 0.0)
                    # since the last line segment (and respective area) do not affect the AP value
                    area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
                    area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
                    plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
                    # set window title
                    fig = plt.gcf()  # gcf - get current figure
                    fig.canvas.manager.set_window_title('AP ' + class_name)
                    # set plot title
                    plt.title('class: ' + text)
                    # plt.suptitle('This is a somewhat long figure title', fontsize=16)
                    # set axis titles
                    plt.xlabel('Recall')
                    plt.ylabel('Precision')
                    # optional - set axes
                    axes = plt.gca()  # gca - get current axes
                    axes.set_xlim([0.0, 1.0])
                    axes.set_ylim([0.0, 1.05])  # .05 to give some extra space
                    # Alternative option -> wait for button to be pressed
                    # while not plt.waitforbuttonpress(): pass  # wait for key display
                    # Alternative option -> normal display
                    plt.show()
                    # save the plot
                    # fig.savefig(output_files_path + "/classes/" + class_name + ".png")
                    # plt.cla()  # clear axes for next plot

            # if show_animation:
            #     cv2.destroyAllWindows()

            output_file.write("\n# mAP of all classes\n")
            mAP = sum_AP / n_classes
            text = "mAP = {0:.2f}%".format(mAP * 100)
            output_file.write(text + "\n")
            print(text)

        """
        Count total of detection-results
        """
        # iterate through all the files
        det_counter_per_class = {}
        for txt_file in dr_files_list:
            # get lines to list
            lines_list = read_txt_to_list(txt_file)
            for line in lines_list:
                class_name = line.split()[0]
                # check if class is in the ignore list, if yes skip
                # if class_name in args.ignore:
                #     continue
                # count that object
                if class_name in det_counter_per_class:
                    det_counter_per_class[class_name] += 1
                else:
                    # if class didn't exist yet
                    det_counter_per_class[class_name] = 1
        # print(det_counter_per_class)
        dr_classes = list(det_counter_per_class.keys())

        """
        Plot the total number of occurrences of each class in the ground-truth
        """
        if True:
            window_title = "ground-truth-info"
            plot_title = "ground-truth\n"
            plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
            x_label = "Number of objects per class"
            output_path = output_files_path + "/ground-truth-info.png"
            to_show = False
            plot_color = 'forestgreen'
            draw_plot_func(
                gt_counter_per_class,
                n_classes,
                window_title,
                plot_title,
                x_label,
                output_path,
                to_show,
                plot_color,
                '',
            )

        """
        Finish counting true positives
        """
        for class_name in dr_classes:
            # if class exists in detection-result but not in ground-truth then there are no true positives in that class
            if class_name not in gt_classes:
                count_true_positives[class_name] = 0
        # print(count_true_positives)

        """
        Plot the total number of occurrences of each class in the "detection-results" folder
        """
        if True:
            window_title = "detection-results-info"
            # Plot title
            plot_title = "detection-results\n"
            plot_title += "(" + str(len(dr_files_list)) + " files and "
            count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
            plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
            # end Plot title
            x_label = "Number of objects per class"
            output_path = output_files_path + "/detection-results-info.png"
            to_show = False
            plot_color = 'forestgreen'
            true_p_bar = count_true_positives
            draw_plot_func(
                det_counter_per_class,
                len(det_counter_per_class),
                window_title,
                plot_title,
                x_label,
                output_path,
                to_show,
                plot_color,
                true_p_bar
            )

        """
        Draw mAP plot (show APs of all classes in decreasing order)
        """
        if True:
            window_title = "mAP"
            plot_title = "mAP = {0:.2f}%".format(mAP * 100)
            x_label = "Average Precision"
            output_path = output_files_path + "/mAP.png"
            to_show = True
            plot_color = 'royalblue'
            draw_plot_func(
                ap_dictionary,
                n_classes,
                window_title,
                plot_title,
                x_label,
                output_path,
                to_show,
                plot_color,
                ""
            )

    def predict_raw(self, img_path):
        raw_img = cv2.imread(img_path)
        print('img shape: ', raw_img.shape)
        img = self.preprocess_img(raw_img)
        imgs = np.expand_dims(img, axis=0)
        return self.yolo_model.predict(imgs)

    def predict_nonms(self, img_path, iou_threshold=0.413, score_threshold=0.1):
        raw_img = cv2.imread(img_path)
        print('img shape: ', raw_img.shape)
        img = self.preprocess_img(raw_img)
        imgs = np.expand_dims(img, axis=0)
        yolov4_output = self.yolo_model.predict(imgs)
        output = yolov4_head(yolov4_output, self.num_classes, self.anchors, self.xyscale)
        pred_output = nms(output, self.img_size, self.num_classes, iou_threshold, score_threshold)
        pred_output = [p.numpy() for p in pred_output]
        detections = get_detection_data(img=raw_img,
                                        model_outputs=pred_output,
                                        class_names=self.class_names)
        draw_bbox(raw_img, detections, cmap=self.class_color, random_color=True)
        return detections
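Tying the class together, a hedged usage sketch (street.jpg is a placeholder path): predict_img runs the graph with NMS fused in, while predict_nonms applies NMS outside the graph so the thresholds can be varied without rebuilding the model.

import cv2
from models import Yolov4

model = Yolov4(weight_path='yolov4.weights', class_name_path='coco_classes.txt')
drawn, detections = model.predict_img(cv2.imread('street.jpg')[..., ::-1])  # BGR -> RGB
detections = model.predict_nonms('street.jpg', iou_threshold=0.413, score_threshold=0.1)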
requirements.txt
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
absl-py==1.3.0
aiohttp==3.8.3
aiosignal==1.3.1
anyio==3.6.2
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.1.0
bcrypt==4.0.1
cachetools==5.2.0
certifi==2022.9.24
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.6
contourpy==1.0.6
cryptography==38.0.3
cycler==0.11.0
fastapi==0.87.0
ffmpy==0.3.0
flatbuffers==22.10.26
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2022.11.0
gast==0.4.0
google-auth==2.14.1
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
gradio==3.10.0
grpcio==1.50.0
h11==0.12.0
h5py==3.7.0
httpcore==0.15.0
httpx==0.23.1
idna==3.4
importlib-metadata==5.0.0
Jinja2==3.1.2
joblib==1.2.0
keras==2.11.0
kiwisolver==1.4.4
libclang==14.0.6
linkify-it-py==1.0.3
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.6.2
mdit-py-plugins==0.3.1
mdurl==0.1.2
multidict==6.0.2
numpy==1.23.4
oauthlib==3.2.2
opencv-python==4.6.0.66
opt-einsum==3.3.0
orjson==3.8.1
packaging==21.3
pandas==1.5.1
paramiko==2.12.0
Pillow==9.3.0
protobuf==3.19.6
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pycryptodome==3.15.0
pydantic==1.10.2
pydub==0.25.1
PyNaCl==1.5.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-multipart==0.0.5
pytz==2022.6
PyYAML==6.0
requests==2.28.1
requests-oauthlib==1.3.1
rfc3986==1.5.0
rsa==4.9
scikit-learn==1.1.3
scipy==1.9.3
six==1.16.0
sniffio==1.3.0
starlette==0.21.0
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
termcolor==2.1.0
threadpoolctl==3.1.0
tqdm==4.64.1
typing_extensions==4.4.0
uc-micro-py==1.0.1
urllib3==1.26.12
uvicorn==0.19.0
websockets==10.4
Werkzeug==2.2.2
wrapt==1.14.1
yarl==1.8.1
zipp==3.10.0
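
These pins target TensorFlow 2.11 and Gradio 3.10; the environment can be reproduced in the usual way (sketch):

pip install -r requirements.txt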
utils.py
ADDED
@@ -0,0 +1,475 @@
import numpy as np
import cv2
import pandas as pd
import operator
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import Sequence
from config import yolo_config


def load_weights(model, weights_file_path):
    conv_layer_size = 110
    conv_output_idxs = [93, 101, 109]
    with open(weights_file_path, 'rb') as file:
        major, minor, revision, seen, _ = np.fromfile(file, dtype=np.int32, count=5)

        bn_idx = 0
        for conv_idx in range(conv_layer_size):
            conv_layer_name = f'conv2d_{conv_idx}' if conv_idx > 0 else 'conv2d'
            bn_layer_name = f'batch_normalization_{bn_idx}' if bn_idx > 0 else 'batch_normalization'

            conv_layer = model.get_layer(conv_layer_name)
            filters = conv_layer.filters
            kernel_size = conv_layer.kernel_size[0]
            input_dims = conv_layer.input_shape[-1]

            if conv_idx not in conv_output_idxs:
                # darknet bn layer weights: [beta, gamma, mean, variance]
                bn_weights = np.fromfile(file, dtype=np.float32, count=4 * filters)
                # tf bn layer weights: [gamma, beta, mean, variance]
                bn_weights = bn_weights.reshape((4, filters))[[1, 0, 2, 3]]
                bn_layer = model.get_layer(bn_layer_name)
                bn_idx += 1
            else:
                conv_bias = np.fromfile(file, dtype=np.float32, count=filters)

            # darknet shape: (out_dim, input_dims, height, width)
            # tf shape: (height, width, input_dims, out_dim)
            conv_shape = (filters, input_dims, kernel_size, kernel_size)
            conv_weights = np.fromfile(file, dtype=np.float32, count=np.product(conv_shape))
            conv_weights = conv_weights.reshape(conv_shape).transpose([2, 3, 1, 0])

            if conv_idx not in conv_output_idxs:
                conv_layer.set_weights([conv_weights])
                bn_layer.set_weights(bn_weights)
            else:
                conv_layer.set_weights([conv_weights, conv_bias])

        # read() consumes the rest of the file, so read once and reuse the result
        remaining = file.read()
        if len(remaining) == 0:
            print('all weights read')
        else:
            print(f'failed to read all weights, # of unread bytes: {len(remaining)}')

def get_detection_data(img, model_outputs, class_names):
    """
    Convert raw model outputs into a DataFrame of detections.
    :param img: target raw image
    :param model_outputs: outputs from inference_model
    :param class_names: list of object class names
    :return: pandas DataFrame with columns x1, y1, x2, y2, class_name, score, w, h
    """
    num_bboxes = model_outputs[-1][0]
    boxes, scores, classes = [output[0][:num_bboxes] for output in model_outputs[:-1]]

    h, w = img.shape[:2]
    df = pd.DataFrame(boxes, columns=['x1', 'y1', 'x2', 'y2'])
    df[['x1', 'x2']] = (df[['x1', 'x2']] * w).astype('int64')
    df[['y1', 'y2']] = (df[['y1', 'y2']] * h).astype('int64')
    df['class_name'] = np.array(class_names)[classes.astype('int64')]
    df['score'] = scores
    df['w'] = df['x2'] - df['x1']
    df['h'] = df['y2'] - df['y1']

    print(f'# of bboxes: {num_bboxes}')
    return df

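# Illustrative result of get_detection_data (hypothetical values, not from a real run):
#        x1   y1   x2   y2 class_name  score    w    h
#   0   102   64  311  298        dog   0.87  209  234
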
def read_annotation_lines(annotation_path, test_size=None, random_seed=5566):
    with open(annotation_path) as f:
        lines = f.readlines()
    if test_size:
        return train_test_split(lines, test_size=test_size, random_state=random_seed)
    else:
        return lines

def draw_bbox(img, detections, cmap, random_color=True, figsize=(10, 10), show_img=True, show_text=True):
    """
    Draw bounding boxes on the img.
    :param img: BGR img.
    :param detections: pandas DataFrame containing detections
    :param cmap: object colormap
    :param random_color: assign a random color to each object if True
    :param show_img: whether to plot the img with bboxes
    :param show_text: whether to draw the class name and score above each box
    :return: img with bboxes drawn
    """
    img = np.array(img)
    scale = max(img.shape[0:2]) / 416
    line_width = int(2 * scale)

    for _, row in detections.iterrows():
        x1, y1, x2, y2, cls, score, w, h = row.values
        color = list(np.random.random(size=3) * 255) if random_color else cmap[cls]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, line_width)
        if show_text:
            text = f'{cls} {score:.2f}'
            font = cv2.FONT_HERSHEY_DUPLEX
            font_scale = max(0.3 * scale, 0.3)
            thickness = max(int(1 * scale), 1)
            (text_width, text_height) = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)[0]
            cv2.rectangle(img, (x1 - line_width // 2, y1 - text_height), (x1 + text_width, y1), color, cv2.FILLED)
            cv2.putText(img, text, (x1, y1), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
    if show_img:
        plt.figure(figsize=figsize)
        plt.imshow(img)
        plt.show()
    return img


class DataGenerator(Sequence):
    """
    Generates data for Keras
    ref: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    """
    def __init__(self,
                 annotation_lines,
                 class_name_path,
                 folder_path,
                 max_boxes=100,
                 shuffle=True):
        self.annotation_lines = annotation_lines
        self.class_name_path = class_name_path
        self.num_classes = len([line.strip() for line in open(class_name_path).readlines()])
        self.num_gpu = yolo_config['num_gpu']
        self.batch_size = yolo_config['batch_size'] * self.num_gpu
        self.target_img_size = yolo_config['img_size']
        self.anchors = np.array(yolo_config['anchors']).reshape((9, 2))
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.annotation_lines))
        self.folder_path = folder_path
        self.max_boxes = max_boxes
        self.on_epoch_end()

    def __len__(self):
        'number of batches per epoch'
        return int(np.ceil(len(self.annotation_lines) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'

        # Generate indexes of the batch
        idxs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        # Find list of IDs
        lines = [self.annotation_lines[i] for i in idxs]

        # Generate data
        X, y_tensor, y_bbox = self.__data_generation(lines)

        return [X, *y_tensor, y_bbox], np.zeros(len(lines))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, annotation_lines):
        """
        Generates data containing batch_size samples
        :param annotation_lines: list of annotation lines for this batch
        :return: images, y_tensor per output stage, and true boxes in xywh format
        """
        X = np.empty((len(annotation_lines), *self.target_img_size), dtype=np.float32)
        y_bbox = np.empty((len(annotation_lines), self.max_boxes, 5), dtype=np.float32)  # x1y1x2y2

        for i, line in enumerate(annotation_lines):
            img_data, box_data = self.get_data(line)
            X[i] = img_data
            y_bbox[i] = box_data

        y_tensor, y_true_boxes_xywh = preprocess_true_boxes(y_bbox, self.target_img_size[:2], self.anchors, self.num_classes)

        return X, y_tensor, y_true_boxes_xywh

    def get_data(self, annotation_line):
        line = annotation_line.split()
        img_path = line[0]
        img = cv2.imread(os.path.join(self.folder_path, img_path))[:, :, ::-1]
        ih, iw = img.shape[:2]
        h, w, c = self.target_img_size
        boxes = np.array([np.array(list(map(float, box.split(',')))) for box in line[1:]], dtype=np.float32)  # x1y1x2y2
        scale_w, scale_h = w / iw, h / ih
        img = cv2.resize(img, (w, h))
        image_data = np.array(img) / 255.

        # correct boxes coordinates
        box_data = np.zeros((self.max_boxes, 5))
        if len(boxes) > 0:
            np.random.shuffle(boxes)
            boxes = boxes[:self.max_boxes]
            boxes[:, [0, 2]] = boxes[:, [0, 2]] * scale_w  # + dx
            boxes[:, [1, 3]] = boxes[:, [1, 3]] * scale_h  # + dy
            box_data[:len(boxes)] = boxes

        return image_data, box_data

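# Usage sketch (illustrative; the annotation and image paths are hypothetical):
#   train_lines, val_lines = read_annotation_lines('dataset/txt/anno.txt', test_size=0.1)
#   train_gen = DataGenerator(train_lines, 'class_names/bccd_classes.txt', 'dataset/img')
# Each batch is ([X, *y_tensor, y_bbox], zeros); the zero array is a dummy target,
# a common pattern when the loss is computed inside the model graph itself.
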
def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes):
    '''Preprocess true boxes to training input format

    Parameters
    ----------
    true_boxes: array, shape=(bs, max boxes per img, 5)
        Absolute x_min, y_min, x_max, y_max, class_id relative to input_shape.
    input_shape: array-like, hw, multiples of 32
    anchors: array, shape=(N, 2), (9, wh)
    num_classes: int

    Returns
    -------
    y_true: list of array, shape like yolo_outputs, xywh are relative values

    '''
    num_stages = 3  # default setting for yolo, tiny yolo will be 2
    anchor_mask = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    bbox_per_grid = 3
    true_boxes = np.array(true_boxes, dtype='float32')
    true_boxes_abs = np.array(true_boxes, dtype='float32')
    input_shape = np.array(input_shape, dtype='int32')
    true_boxes_xy = (true_boxes_abs[..., 0:2] + true_boxes_abs[..., 2:4]) // 2  # (bs, 100, 2)
    true_boxes_wh = true_boxes_abs[..., 2:4] - true_boxes_abs[..., 0:2]  # (bs, 100, 2)

    # Normalize x, y, w, h relative to img size -> (0~1)
    true_boxes[..., 0:2] = true_boxes_xy / input_shape[::-1]  # xy
    true_boxes[..., 2:4] = true_boxes_wh / input_shape[::-1]  # wh

    bs = true_boxes.shape[0]
    grid_sizes = [input_shape // {0: 8, 1: 16, 2: 32}[stage] for stage in range(num_stages)]
    y_true = [np.zeros((bs,
                        grid_sizes[s][0],
                        grid_sizes[s][1],
                        bbox_per_grid,
                        5 + num_classes), dtype='float32')
              for s in range(num_stages)]
    # [(?, 52, 52, 3, 5+num_classes) (?, 26, 26, 3, 5+num_classes) (?, 13, 13, 3, 5+num_classes)]
    y_true_boxes_xywh = np.concatenate((true_boxes_xy, true_boxes_wh), axis=-1)
    # Expand dim to apply broadcasting.
    anchors = np.expand_dims(anchors, 0)  # (1, 9, 2)
    anchor_maxes = anchors / 2.  # (1, 9, 2)
    anchor_mins = -anchor_maxes  # (1, 9, 2)
    valid_mask = true_boxes_wh[..., 0] > 0  # (bs, 100)

    for batch_idx in range(bs):
        # Discard zero rows.
        wh = true_boxes_wh[batch_idx, valid_mask[batch_idx]]  # (# of bbox, 2)
        num_boxes = len(wh)
        if num_boxes == 0: continue
        wh = np.expand_dims(wh, -2)  # (# of bbox, 1, 2)
        box_maxes = wh / 2.  # (# of bbox, 1, 2)
        box_mins = -box_maxes  # (# of bbox, 1, 2)

        # Compute IoU between each anchor and true boxes for responsibility assignment
        intersect_mins = np.maximum(box_mins, anchor_mins)  # (# of bbox, 9, 2)
        intersect_maxes = np.minimum(box_maxes, anchor_maxes)
        intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.)
        intersect_area = np.prod(intersect_wh, axis=-1)  # (# of bbox, 9)
        box_area = wh[..., 0] * wh[..., 1]  # (# of bbox, 1)
        anchor_area = anchors[..., 0] * anchors[..., 1]  # (1, 9)
        iou = intersect_area / (box_area + anchor_area - intersect_area)  # (# of bbox, 9)

        # Find best anchor for each true box
        best_anchors = np.argmax(iou, axis=-1)  # (# of bbox,)
        for box_idx in range(num_boxes):
            best_anchor = best_anchors[box_idx]
            for stage in range(num_stages):
                if best_anchor in anchor_mask[stage]:
                    x_offset = true_boxes[batch_idx, box_idx, 0] * grid_sizes[stage][1]
                    y_offset = true_boxes[batch_idx, box_idx, 1] * grid_sizes[stage][0]
                    # Grid Index
                    grid_col = np.floor(x_offset).astype('int32')
                    grid_row = np.floor(y_offset).astype('int32')
                    anchor_idx = anchor_mask[stage].index(best_anchor)
                    class_idx = true_boxes[batch_idx, box_idx, 4].astype('int32')
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 0] = x_offset - grid_col  # x
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 1] = y_offset - grid_row  # y
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, :4] = true_boxes_abs[batch_idx, box_idx, :4]  # abs xywh
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, :2] = true_boxes_xy[batch_idx, box_idx, :]  # abs xy
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 2:4] = true_boxes_wh[batch_idx, box_idx, :]  # abs wh
                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 4] = 1  # confidence

                    y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 5 + class_idx] = 1  # one-hot encoding
                    # label smoothing (disabled):
                    # onehot = np.zeros(num_classes, dtype=np.float)
                    # onehot[class_idx] = 1.0
                    # uniform_distribution = np.full(num_classes, 1.0 / num_classes)
                    # delta = 0.01
                    # smooth_onehot = onehot * (1 - delta) + delta * uniform_distribution
                    # y_true[stage][batch_idx, grid_row, grid_col, anchor_idx, 5:] = smooth_onehot

    return y_true, y_true_boxes_xywh

"""
|
306 |
+
Calculate the AP given the recall and precision array
|
307 |
+
1st) We compute a version of the measured precision/recall curve with
|
308 |
+
precision monotonically decreasing
|
309 |
+
2nd) We compute the AP as the area under this curve by numerical integration.
|
310 |
+
"""
|
311 |
+
def voc_ap(rec, prec):
|
312 |
+
"""
|
313 |
+
--- Official matlab code VOC2012---
|
314 |
+
mrec=[0 ; rec ; 1];
|
315 |
+
mpre=[0 ; prec ; 0];
|
316 |
+
for i=numel(mpre)-1:-1:1
|
317 |
+
mpre(i)=max(mpre(i),mpre(i+1));
|
318 |
+
end
|
319 |
+
i=find(mrec(2:end)~=mrec(1:end-1))+1;
|
320 |
+
ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
|
321 |
+
"""
|
322 |
+
rec.insert(0, 0.0) # insert 0.0 at begining of list
|
323 |
+
rec.append(1.0) # insert 1.0 at end of list
|
324 |
+
mrec = rec[:]
|
325 |
+
prec.insert(0, 0.0) # insert 0.0 at begining of list
|
326 |
+
prec.append(0.0) # insert 0.0 at end of list
|
327 |
+
mpre = prec[:]
|
328 |
+
"""
|
329 |
+
This part makes the precision monotonically decreasing
|
330 |
+
(goes from the end to the beginning)
|
331 |
+
matlab: for i=numel(mpre)-1:-1:1
|
332 |
+
mpre(i)=max(mpre(i),mpre(i+1));
|
333 |
+
"""
|
334 |
+
# matlab indexes start in 1 but python in 0, so I have to do:
|
335 |
+
# range(start=(len(mpre) - 2), end=0, step=-1)
|
336 |
+
# also the python function range excludes the end, resulting in:
|
337 |
+
# range(start=(len(mpre) - 2), end=-1, step=-1)
|
338 |
+
for i in range(len(mpre)-2, -1, -1):
|
339 |
+
mpre[i] = max(mpre[i], mpre[i+1])
|
340 |
+
"""
|
341 |
+
This part creates a list of indexes where the recall changes
|
342 |
+
matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
|
343 |
+
"""
|
344 |
+
i_list = []
|
345 |
+
for i in range(1, len(mrec)):
|
346 |
+
if mrec[i] != mrec[i-1]:
|
347 |
+
i_list.append(i) # if it was matlab would be i + 1
|
348 |
+
"""
|
349 |
+
The Average Precision (AP) is the area under the curve
|
350 |
+
(numerical integration)
|
351 |
+
matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
|
352 |
+
"""
|
353 |
+
ap = 0.0
|
354 |
+
for i in i_list:
|
355 |
+
ap += ((mrec[i]-mrec[i-1])*mpre[i])
|
356 |
+
return ap, mrec, mpre
|
357 |
+
|
358 |
+
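# Worked example (illustrative, not part of the original file): for
#   rec = [0.5, 1.0], prec = [1.0, 0.5]
# the padded arrays become mrec = [0, 0.5, 1.0, 1.0] and, after the
# monotonic pass, mpre = [1.0, 1.0, 0.5, 0.0], so
#   ap = (0.5 - 0) * 1.0 + (1.0 - 0.5) * 0.5 = 0.75
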
"""
|
359 |
+
Draw plot using Matplotlib
|
360 |
+
"""
|
361 |
+
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
|
362 |
+
# sort the dictionary by decreasing value, into a list of tuples
|
363 |
+
sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
|
364 |
+
print(sorted_dic_by_value)
|
365 |
+
# unpacking the list of tuples into two lists
|
366 |
+
sorted_keys, sorted_values = zip(*sorted_dic_by_value)
|
367 |
+
#
|
368 |
+
if true_p_bar != "":
|
369 |
+
"""
|
370 |
+
Special case to draw in:
|
371 |
+
- green -> TP: True Positives (object detected and matches ground-truth)
|
372 |
+
- red -> FP: False Positives (object detected but does not match ground-truth)
|
373 |
+
- pink -> FN: False Negatives (object not detected but present in the ground-truth)
|
374 |
+
"""
|
375 |
+
fp_sorted = []
|
376 |
+
tp_sorted = []
|
377 |
+
for key in sorted_keys:
|
378 |
+
fp_sorted.append(dictionary[key] - true_p_bar[key])
|
379 |
+
tp_sorted.append(true_p_bar[key])
|
380 |
+
plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
|
381 |
+
plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
|
382 |
+
# add legend
|
383 |
+
plt.legend(loc='lower right')
|
384 |
+
"""
|
385 |
+
Write number on side of bar
|
386 |
+
"""
|
387 |
+
fig = plt.gcf() # gcf - get current figure
|
388 |
+
axes = plt.gca()
|
389 |
+
r = fig.canvas.get_renderer()
|
390 |
+
for i, val in enumerate(sorted_values):
|
391 |
+
fp_val = fp_sorted[i]
|
392 |
+
tp_val = tp_sorted[i]
|
393 |
+
fp_str_val = " " + str(fp_val)
|
394 |
+
tp_str_val = fp_str_val + " " + str(tp_val)
|
395 |
+
# trick to paint multicolor with offset:
|
396 |
+
# first paint everything and then repaint the first number
|
397 |
+
t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
|
398 |
+
plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
|
399 |
+
if i == (len(sorted_values)-1): # largest bar
|
400 |
+
adjust_axes(r, t, fig, axes)
|
401 |
+
else:
|
402 |
+
plt.barh(range(n_classes), sorted_values, color=plot_color)
|
403 |
+
"""
|
404 |
+
Write number on side of bar
|
405 |
+
"""
|
406 |
+
fig = plt.gcf() # gcf - get current figure
|
407 |
+
axes = plt.gca()
|
408 |
+
r = fig.canvas.get_renderer()
|
409 |
+
for i, val in enumerate(sorted_values):
|
410 |
+
str_val = " " + str(val) # add a space before
|
411 |
+
if val < 1.0:
|
412 |
+
str_val = " {0:.2f}".format(val)
|
413 |
+
t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
|
414 |
+
# re-set axes to show number inside the figure
|
415 |
+
if i == (len(sorted_values)-1): # largest bar
|
416 |
+
adjust_axes(r, t, fig, axes)
|
417 |
+
# set window title
|
418 |
+
fig.canvas.set_window_title(window_title)
|
419 |
+
# write classes in y axis
|
420 |
+
tick_font_size = 12
|
421 |
+
plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
|
422 |
+
"""
|
423 |
+
Re-scale height accordingly
|
424 |
+
"""
|
425 |
+
init_height = fig.get_figheight()
|
426 |
+
# comput the matrix height in points and inches
|
427 |
+
dpi = fig.dpi
|
428 |
+
height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
|
429 |
+
height_in = height_pt / dpi
|
430 |
+
# compute the required figure height
|
431 |
+
top_margin = 0.15 # in percentage of the figure height
|
432 |
+
bottom_margin = 0.05 # in percentage of the figure height
|
433 |
+
figure_height = height_in / (1 - top_margin - bottom_margin)
|
434 |
+
# set new height
|
435 |
+
if figure_height > init_height:
|
436 |
+
fig.set_figheight(figure_height)
|
437 |
+
|
438 |
+
# set plot title
|
439 |
+
plt.title(plot_title, fontsize=14)
|
440 |
+
# set axis titles
|
441 |
+
# plt.xlabel('classes')
|
442 |
+
plt.xlabel(x_label, fontsize='large')
|
443 |
+
# adjust size of window
|
444 |
+
fig.tight_layout()
|
445 |
+
# save the plot
|
446 |
+
fig.savefig(output_path)
|
447 |
+
# show image
|
448 |
+
# if to_show:
|
449 |
+
plt.show()
|
450 |
+
# close the plot
|
451 |
+
# plt.close()
|
452 |
+
|
453 |
+
"""
|
454 |
+
Plot - adjust axes
|
455 |
+
"""
|
456 |
+
def adjust_axes(r, t, fig, axes):
|
457 |
+
# get text width for re-scaling
|
458 |
+
bb = t.get_window_extent(renderer=r)
|
459 |
+
text_width_inches = bb.width / fig.dpi
|
460 |
+
# get axis width in inches
|
461 |
+
current_fig_width = fig.get_figwidth()
|
462 |
+
new_fig_width = current_fig_width + text_width_inches
|
463 |
+
propotion = new_fig_width / current_fig_width
|
464 |
+
# get axis limit
|
465 |
+
x_lim = axes.get_xlim()
|
466 |
+
axes.set_xlim([x_lim[0], x_lim[1]*propotion])
|
467 |
+
|
468 |
+
|
469 |
+
def read_txt_to_list(path):
|
470 |
+
# open txt file lines to a list
|
471 |
+
with open(path) as f:
|
472 |
+
content = f.readlines()
|
473 |
+
# remove whitespace characters like `\n` at the end of each line
|
474 |
+
content = [x.strip() for x in content]
|
475 |
+
return content
|
xml_to_txt.py
ADDED
@@ -0,0 +1,42 @@
import xml.etree.ElementTree as ET
import os
from glob import glob

XML_PATH = './dataset/xml'
CLASSES_PATH = './class_names/classes.txt'
TXT_PATH = './dataset/txt/anno.txt'


'''loads the classes'''
def get_classes(classes_path):
    with open(classes_path) as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names


classes = get_classes(CLASSES_PATH)
assert len(classes) > 0, 'no class names detected!'
print(f'num classes: {len(classes)}')

# output file
list_file = open(TXT_PATH, 'w')

for path in glob(os.path.join(XML_PATH, '*.xml')):
    in_file = open(path)

    # Parse .xml file
    tree = ET.parse(in_file)
    root = tree.getroot()
    # Write object information to .txt file
    file_name = root.find('filename').text
    print(file_name)
    list_file.write(file_name)
    for obj in root.iter('object'):
        cls = obj.find('name').text
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
    list_file.write('\n')
list_file.close()
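
Each output line pairs an image filename with space-separated boxes in the x_min,y_min,x_max,y_max,class_id format that read_annotation_lines and DataGenerator.get_data in utils.py consume. An illustrative line (hypothetical filename and coordinates):

image_001.jpg 48,240,195,371,0 8,12,352,498,2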
yolov4.weights
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8a4f6c62188738d86dc6898d82724ec0964d0eb9d2ae0f0a9d53d65d108d562
size 257717640