glenn-jocher
commited on
Commit
•
1542cca
1
Parent(s):
9ef4760
Update labels.png with rectangles (#1432)
Browse files- utils/plots.py +21 -7
utils/plots.py
CHANGED
@@ -13,7 +13,7 @@ import matplotlib.pyplot as plt
|
|
13 |
import numpy as np
|
14 |
import torch
|
15 |
import yaml
|
16 |
-
from PIL import Image
|
17 |
from scipy.signal import butter, filtfilt
|
18 |
|
19 |
from utils.general import xywh2xyxy, xyxy2xywh
|
@@ -266,17 +266,31 @@ def plot_labels(labels, save_dir=''):
|
|
266 |
# plot dataset labels
|
267 |
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
268 |
nc = int(c.max() + 1) # number of classes
|
|
|
269 |
|
270 |
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
271 |
ax = ax.ravel()
|
272 |
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
273 |
ax[0].set_xlabel('classes')
|
274 |
-
ax[
|
275 |
-
ax[
|
276 |
-
ax[
|
277 |
-
ax[
|
278 |
-
ax[
|
279 |
-
ax[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
|
281 |
plt.close()
|
282 |
|
|
|
13 |
import numpy as np
|
14 |
import torch
|
15 |
import yaml
|
16 |
+
from PIL import Image, ImageDraw
|
17 |
from scipy.signal import butter, filtfilt
|
18 |
|
19 |
from utils.general import xywh2xyxy, xyxy2xywh
|
|
|
266 |
# plot dataset labels
|
267 |
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
268 |
nc = int(c.max() + 1) # number of classes
|
269 |
+
colors = color_list()
|
270 |
|
271 |
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
272 |
ax = ax.ravel()
|
273 |
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
274 |
ax[0].set_xlabel('classes')
|
275 |
+
ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
|
276 |
+
ax[2].set_xlabel('x')
|
277 |
+
ax[2].set_ylabel('y')
|
278 |
+
ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
|
279 |
+
ax[3].set_xlabel('width')
|
280 |
+
ax[3].set_ylabel('height')
|
281 |
+
|
282 |
+
# rectangles
|
283 |
+
labels[:, 1:3] = 0.5 # center
|
284 |
+
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
|
285 |
+
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
|
286 |
+
for cls, *box in labels[:1000]:
|
287 |
+
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
|
288 |
+
ax[1].imshow(img)
|
289 |
+
ax[1].axis('off')
|
290 |
+
|
291 |
+
for a in [0, 1, 2, 3]:
|
292 |
+
for s in ['top', 'right', 'left', 'bottom']:
|
293 |
+
ax[a].spines[s].set_visible(False)
|
294 |
plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
|
295 |
plt.close()
|
296 |
|