glenn-jocher
commited on
Commit
•
c1a2a7a
1
Parent(s):
8074745
hyperparameter evolution bug fix (#566)
Browse files- train.py +2 -2
- utils/utils.py +16 -12
train.py
CHANGED
@@ -465,7 +465,7 @@ if __name__ == '__main__':
|
|
465 |
# Evolve hyperparameters (optional)
|
466 |
else:
|
467 |
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
|
468 |
-
meta = {'lr0': (1, 1e-5, 1e-
|
469 |
'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
|
470 |
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
|
471 |
'giou': (1, 0.02, 0.2), # GIoU loss gain
|
@@ -534,6 +534,6 @@ if __name__ == '__main__':
|
|
534 |
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
|
535 |
|
536 |
# Plot results
|
537 |
-
|
538 |
print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
|
539 |
'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))
|
|
|
465 |
# Evolve hyperparameters (optional)
|
466 |
else:
|
467 |
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
|
468 |
+
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
|
469 |
'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
|
470 |
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
|
471 |
'giou': (1, 0.02, 0.2), # GIoU loss gain
|
|
|
534 |
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
|
535 |
|
536 |
# Plot results
|
537 |
+
plot_evolution(yaml_file)
|
538 |
print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
|
539 |
'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))
|
utils/utils.py
CHANGED
@@ -919,6 +919,15 @@ def increment_dir(dir, comment=''):
|
|
919 |
|
920 |
|
921 |
# Plotting functions ---------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
922 |
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
923 |
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
|
924 |
def butter_lowpass(cutoff, fs, order):
|
@@ -1130,13 +1139,6 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st
|
|
1130 |
|
1131 |
def plot_labels(labels, save_dir=''):
|
1132 |
# plot dataset labels
|
1133 |
-
def hist2d(x, y, n=100):
|
1134 |
-
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
|
1135 |
-
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
|
1136 |
-
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
|
1137 |
-
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
|
1138 |
-
return np.log(hist[xidx, yidx])
|
1139 |
-
|
1140 |
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
1141 |
nc = int(c.max() + 1) # number of classes
|
1142 |
|
@@ -1154,23 +1156,25 @@ def plot_labels(labels, save_dir=''):
|
|
1154 |
plt.close()
|
1155 |
|
1156 |
|
1157 |
-
def
|
1158 |
# Plot hyperparameter evolution results in evolve.txt
|
1159 |
with open(yaml_file) as f:
|
1160 |
hyp = yaml.load(f, Loader=yaml.FullLoader)
|
1161 |
x = np.loadtxt('evolve.txt', ndmin=2)
|
1162 |
f = fitness(x)
|
1163 |
# weights = (f - f.min()) ** 2 # for weighted results
|
1164 |
-
plt.figure(figsize=(
|
1165 |
matplotlib.rc('font', **{'size': 8})
|
1166 |
for i, (k, v) in enumerate(hyp.items()):
|
1167 |
y = x[:, i + 7]
|
1168 |
# mu = (y * weights).sum() / weights.sum() # best weighted result
|
1169 |
mu = y[f.argmax()] # best single result
|
1170 |
-
plt.subplot(
|
1171 |
-
plt.
|
1172 |
-
plt.plot(
|
1173 |
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
|
|
|
|
|
1174 |
print('%15s: %.3g' % (k, mu))
|
1175 |
plt.savefig('evolve.png', dpi=200)
|
1176 |
print('\nPlot saved as evolve.png')
|
|
|
919 |
|
920 |
|
921 |
# Plotting functions ---------------------------------------------------------------------------------------------------
|
922 |
+
def hist2d(x, y, n=100):
|
923 |
+
# 2d histogram used in labels.png and evolve.png
|
924 |
+
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
|
925 |
+
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
|
926 |
+
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
|
927 |
+
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
|
928 |
+
return np.log(hist[xidx, yidx])
|
929 |
+
|
930 |
+
|
931 |
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
932 |
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
|
933 |
def butter_lowpass(cutoff, fs, order):
|
|
|
1139 |
|
1140 |
def plot_labels(labels, save_dir=''):
|
1141 |
# plot dataset labels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1142 |
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
1143 |
nc = int(c.max() + 1) # number of classes
|
1144 |
|
|
|
1156 |
plt.close()
|
1157 |
|
1158 |
|
1159 |
+
def plot_evolution(yaml_file='runs/evolve/hyp_evolved.yaml'): # from utils.utils import *; plot_evolution()
|
1160 |
# Plot hyperparameter evolution results in evolve.txt
|
1161 |
with open(yaml_file) as f:
|
1162 |
hyp = yaml.load(f, Loader=yaml.FullLoader)
|
1163 |
x = np.loadtxt('evolve.txt', ndmin=2)
|
1164 |
f = fitness(x)
|
1165 |
# weights = (f - f.min()) ** 2 # for weighted results
|
1166 |
+
plt.figure(figsize=(10, 10), tight_layout=True)
|
1167 |
matplotlib.rc('font', **{'size': 8})
|
1168 |
for i, (k, v) in enumerate(hyp.items()):
|
1169 |
y = x[:, i + 7]
|
1170 |
# mu = (y * weights).sum() / weights.sum() # best weighted result
|
1171 |
mu = y[f.argmax()] # best single result
|
1172 |
+
plt.subplot(5, 5, i + 1)
|
1173 |
+
plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
|
1174 |
+
plt.plot(mu, f.max(), 'k+', markersize=15)
|
1175 |
plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
|
1176 |
+
if i % 5 != 0:
|
1177 |
+
plt.yticks([])
|
1178 |
print('%15s: %.3g' % (k, mu))
|
1179 |
plt.savefig('evolve.png', dpi=200)
|
1180 |
print('\nPlot saved as evolve.png')
|