glenn-jocher commited on
Commit
1542cca
1 Parent(s): 9ef4760

Update labels.png with rectangles (#1432)

Browse files
Files changed (1) hide show
  1. utils/plots.py +21 -7
utils/plots.py CHANGED
@@ -13,7 +13,7 @@ import matplotlib.pyplot as plt
13
  import numpy as np
14
  import torch
15
  import yaml
16
- from PIL import Image
17
  from scipy.signal import butter, filtfilt
18
 
19
  from utils.general import xywh2xyxy, xyxy2xywh
@@ -266,17 +266,31 @@ def plot_labels(labels, save_dir=''):
266
  # plot dataset labels
267
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
268
  nc = int(c.max() + 1) # number of classes
 
269
 
270
  fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
271
  ax = ax.ravel()
272
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
273
  ax[0].set_xlabel('classes')
274
- ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
275
- ax[1].set_xlabel('x')
276
- ax[1].set_ylabel('y')
277
- ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
278
- ax[2].set_xlabel('width')
279
- ax[2].set_ylabel('height')
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
281
  plt.close()
282
 
 
13
  import numpy as np
14
  import torch
15
  import yaml
16
+ from PIL import Image, ImageDraw
17
  from scipy.signal import butter, filtfilt
18
 
19
  from utils.general import xywh2xyxy, xyxy2xywh
 
266
  # plot dataset labels
267
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
268
  nc = int(c.max() + 1) # number of classes
269
+ colors = color_list()
270
 
271
  fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
272
  ax = ax.ravel()
273
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
274
  ax[0].set_xlabel('classes')
275
+ ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
276
+ ax[2].set_xlabel('x')
277
+ ax[2].set_ylabel('y')
278
+ ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
279
+ ax[3].set_xlabel('width')
280
+ ax[3].set_ylabel('height')
281
+
282
+ # rectangles
283
+ labels[:, 1:3] = 0.5 # center
284
+ labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
285
+ img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
286
+ for cls, *box in labels[:1000]:
287
+ ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
288
+ ax[1].imshow(img)
289
+ ax[1].axis('off')
290
+
291
+ for a in [0, 1, 2, 3]:
292
+ for s in ['top', 'right', 'left', 'bottom']:
293
+ ax[a].spines[s].set_visible(False)
294
  plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
295
  plt.close()
296