glenn-jocher commited on
Commit
f010147
1 Parent(s): 784feae

Update matplotlib.use('Agg') tight (#1583)

Browse files

* Update matplotlib tight_layout=True

* udpate

* udpate

* update

* png to ps

* update

* update

Files changed (3) hide show
  1. utils/autoanchor.py +1 -2
  2. utils/metrics.py +2 -4
  3. utils/plots.py +10 -10
utils/autoanchor.py CHANGED
@@ -124,13 +124,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
124
  # k, d = [None] * 20, [None] * 20
125
  # for i in tqdm(range(1, 21)):
126
  # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
127
- # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
128
  # ax = ax.ravel()
129
  # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
130
  # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
131
  # ax[0].hist(wh[wh[:, 0]<100, 0],400)
132
  # ax[1].hist(wh[wh[:, 1]<100, 1],400)
133
- # fig.tight_layout()
134
  # fig.savefig('wh.png', dpi=200)
135
 
136
  # Evolve
 
124
  # k, d = [None] * 20, [None] * 20
125
  # for i in tqdm(range(1, 21)):
126
  # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
127
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
128
  # ax = ax.ravel()
129
  # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
130
  # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
131
  # ax[0].hist(wh[wh[:, 0]<100, 0],400)
132
  # ax[1].hist(wh[wh[:, 1]<100, 1],400)
 
133
  # fig.savefig('wh.png', dpi=200)
134
 
135
  # Evolve
utils/metrics.py CHANGED
@@ -163,7 +163,7 @@ class ConfusionMatrix:
163
  array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
164
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
165
 
166
- fig = plt.figure(figsize=(12, 9))
167
  sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
168
  labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
169
  sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
@@ -171,7 +171,6 @@ class ConfusionMatrix:
171
  yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
172
  fig.axes[0].set_xlabel('True')
173
  fig.axes[0].set_ylabel('Predicted')
174
- fig.tight_layout()
175
  fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
176
  except Exception as e:
177
  pass
@@ -184,7 +183,7 @@ class ConfusionMatrix:
184
  # Plots ----------------------------------------------------------------------------------------------------------------
185
 
186
  def plot_pr_curve(px, py, ap, save_dir='.', names=()):
187
- fig, ax = plt.subplots(1, 1, figsize=(9, 6))
188
  py = np.stack(py, axis=1)
189
 
190
  if 0 < len(names) < 21: # show mAP in legend if < 10 classes
@@ -199,5 +198,4 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
199
  ax.set_xlim(0, 1)
200
  ax.set_ylim(0, 1)
201
  plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
202
- fig.tight_layout()
203
  fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
 
163
  array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
164
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
165
 
166
+ fig = plt.figure(figsize=(12, 9), tight_layout=True)
167
  sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
168
  labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
169
  sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
 
171
  yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
172
  fig.axes[0].set_xlabel('True')
173
  fig.axes[0].set_ylabel('Predicted')
 
174
  fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
175
  except Exception as e:
176
  pass
 
183
  # Plots ----------------------------------------------------------------------------------------------------------------
184
 
185
  def plot_pr_curve(px, py, ap, save_dir='.', names=()):
186
+ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
187
  py = np.stack(py, axis=1)
188
 
189
  if 0 < len(names) < 21: # show mAP in legend if < 10 classes
 
198
  ax.set_xlim(0, 1)
199
  ax.set_ylim(0, 1)
200
  plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
 
201
  fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
utils/plots.py CHANGED
@@ -21,7 +21,7 @@ from utils.metrics import fitness
21
 
22
  # Settings
23
  matplotlib.rc('font', **{'size': 11})
24
- matplotlib.use('svg') # for writing to files only
25
 
26
 
27
  def color_list():
@@ -73,7 +73,7 @@ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
73
  ya = np.exp(x)
74
  yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
75
 
76
- fig = plt.figure(figsize=(6, 3), dpi=150)
77
  plt.plot(x, ya, '.-', label='YOLOv3')
78
  plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
79
  plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
@@ -83,7 +83,6 @@ def plot_wh_methods(): # from utils.plots import *; plot_wh_methods()
83
  plt.ylabel('output')
84
  plt.grid()
85
  plt.legend()
86
- fig.tight_layout()
87
  fig.savefig('comparison.png', dpi=200)
88
 
89
 
@@ -145,7 +144,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
145
  if boxes.max() <= 1: # if normalized
146
  boxes[[0, 2]] *= w # scale to pixels
147
  boxes[[1, 3]] *= h
148
- elif scale_factor < 1: # absolute coords need scale if image scales
149
  boxes *= scale_factor
150
  boxes[[0, 2]] += block_x
151
  boxes[[1, 3]] += block_y
@@ -188,7 +187,6 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
188
  plt.grid()
189
  plt.xlim(0, epochs)
190
  plt.ylim(0)
191
- plt.tight_layout()
192
  plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
193
 
194
 
@@ -267,12 +265,13 @@ def plot_labels(labels, save_dir=Path(''), loggers=None):
267
  sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
268
  plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
269
  diag_kws=dict(bins=50))
270
- plt.savefig(save_dir / 'labels_correlogram.png', dpi=200)
271
  plt.close()
272
  except Exception as e:
273
  pass
274
 
275
  # matplotlib labels
 
276
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
277
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
278
  ax[0].set_xlabel('classes')
@@ -295,13 +294,15 @@ def plot_labels(labels, save_dir=Path(''), loggers=None):
295
  for a in [0, 1, 2, 3]:
296
  for s in ['top', 'right', 'left', 'bottom']:
297
  ax[a].spines[s].set_visible(False)
298
- plt.savefig(save_dir / 'labels.png', dpi=200)
 
 
299
  plt.close()
300
 
301
  # loggers
302
  for k, v in loggers.items() or {}:
303
  if k == 'wandb' and v:
304
- v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})
305
 
306
 
307
  def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
@@ -353,7 +354,7 @@ def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_re
353
 
354
  def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
355
  # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
356
- fig, ax = plt.subplots(2, 5, figsize=(12, 6))
357
  ax = ax.ravel()
358
  s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
359
  'val Box', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95']
@@ -383,6 +384,5 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
383
  except Exception as e:
384
  print('Warning: Plotting error for %s; %s' % (f, e))
385
 
386
- fig.tight_layout()
387
  ax[1].legend()
388
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)
 
21
 
22
  # Settings
23
  matplotlib.rc('font', **{'size': 11})
24
+ matplotlib.use('Agg') # for writing to files only
25
 
26
 
27
  def color_list():
 
73
  ya = np.exp(x)
74
  yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
75
 
76
+ fig = plt.figure(figsize=(6, 3), tight_layout=True)
77
  plt.plot(x, ya, '.-', label='YOLOv3')
78
  plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
79
  plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
 
83
  plt.ylabel('output')
84
  plt.grid()
85
  plt.legend()
 
86
  fig.savefig('comparison.png', dpi=200)
87
 
88
 
 
144
  if boxes.max() <= 1: # if normalized
145
  boxes[[0, 2]] *= w # scale to pixels
146
  boxes[[1, 3]] *= h
147
+ elif scale_factor < 1: # absolute coords need scale if image scales
148
  boxes *= scale_factor
149
  boxes[[0, 2]] += block_x
150
  boxes[[1, 3]] += block_y
 
187
  plt.grid()
188
  plt.xlim(0, epochs)
189
  plt.ylim(0)
 
190
  plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
191
 
192
 
 
265
  sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
266
  plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
267
  diag_kws=dict(bins=50))
268
+ plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
269
  plt.close()
270
  except Exception as e:
271
  pass
272
 
273
  # matplotlib labels
274
+ matplotlib.use('svg') # faster
275
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
276
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
277
  ax[0].set_xlabel('classes')
 
294
  for a in [0, 1, 2, 3]:
295
  for s in ['top', 'right', 'left', 'bottom']:
296
  ax[a].spines[s].set_visible(False)
297
+
298
+ plt.savefig(save_dir / 'labels.jpg', dpi=200)
299
+ matplotlib.use('Agg')
300
  plt.close()
301
 
302
  # loggers
303
  for k, v in loggers.items() or {}:
304
  if k == 'wandb' and v:
305
+ v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]})
306
 
307
 
308
  def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
 
354
 
355
  def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
356
  # Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
357
+ fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
358
  ax = ax.ravel()
359
  s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
360
  'val Box', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95']
 
384
  except Exception as e:
385
  print('Warning: Plotting error for %s; %s' % (f, e))
386
 
 
387
  ax[1].legend()
388
  fig.savefig(Path(save_dir) / 'results.png', dpi=200)