glenn-jocher
committed on
Commit
•
ca5b10b
1
Parent(s):
0070995
Update train.py (#2290)
Browse files
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* Create train.py
train.py
CHANGED
@@ -146,8 +146,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
146 |
|
147 |
# Results
|
148 |
if ckpt.get('training_results') is not None:
|
149 |
-
|
150 |
-
file.write(ckpt['training_results']) # write results.txt
|
151 |
|
152 |
# Epochs
|
153 |
start_epoch = ckpt['epoch'] + 1
|
@@ -354,7 +353,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
354 |
|
355 |
# Write
|
356 |
with open(results_file, 'a') as f:
|
357 |
-
f.write(s + '%10.4g' * 7 % results + '\n') #
|
358 |
if len(opt.name) and opt.bucket:
|
359 |
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
|
360 |
|
@@ -375,15 +374,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
375 |
best_fitness = fi
|
376 |
|
377 |
# Save model
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
'optimizer': None if final_epoch else optimizer.state_dict(),
|
386 |
-
'wandb_id': wandb_run.id if wandb else None}
|
387 |
|
388 |
# Save last, best and delete
|
389 |
torch.save(ckpt, last)
|
@@ -396,9 +393,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
396 |
if rank in [-1, 0]:
|
397 |
# Strip optimizers
|
398 |
final = best if best.exists() else last # final model
|
399 |
-
for f in
|
400 |
if f.exists():
|
401 |
-
strip_optimizer(f)
|
402 |
if opt.bucket:
|
403 |
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
404 |
|
@@ -415,17 +412,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
415 |
# Test best.pt
|
416 |
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
417 |
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
418 |
-
for
|
419 |
results, _, _ = test.test(opt.data,
|
420 |
batch_size=batch_size * 2,
|
421 |
imgsz=imgsz_test,
|
422 |
-
conf_thres=
|
423 |
-
iou_thres=
|
424 |
-
model=attempt_load(
|
425 |
single_cls=opt.single_cls,
|
426 |
dataloader=testloader,
|
427 |
save_dir=save_dir,
|
428 |
-
save_json=
|
429 |
plots=False)
|
430 |
|
431 |
else:
|
|
|
146 |
|
147 |
# Results
|
148 |
if ckpt.get('training_results') is not None:
|
149 |
+
results_file.write_text(ckpt['training_results']) # write results.txt
|
|
|
150 |
|
151 |
# Epochs
|
152 |
start_epoch = ckpt['epoch'] + 1
|
|
|
353 |
|
354 |
# Write
|
355 |
with open(results_file, 'a') as f:
|
356 |
+
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
|
357 |
if len(opt.name) and opt.bucket:
|
358 |
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
|
359 |
|
|
|
374 |
best_fitness = fi
|
375 |
|
376 |
# Save model
|
377 |
+
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
|
378 |
+
ckpt = {'epoch': epoch,
|
379 |
+
'best_fitness': best_fitness,
|
380 |
+
'training_results': results_file.read_text(),
|
381 |
+
'model': ema.ema,
|
382 |
+
'optimizer': None if final_epoch else optimizer.state_dict(),
|
383 |
+
'wandb_id': wandb_run.id if wandb else None}
|
|
|
|
|
384 |
|
385 |
# Save last, best and delete
|
386 |
torch.save(ckpt, last)
|
|
|
393 |
if rank in [-1, 0]:
|
394 |
# Strip optimizers
|
395 |
final = best if best.exists() else last # final model
|
396 |
+
for f in last, best:
|
397 |
if f.exists():
|
398 |
+
strip_optimizer(f)
|
399 |
if opt.bucket:
|
400 |
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
|
401 |
|
|
|
412 |
# Test best.pt
|
413 |
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
414 |
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
|
415 |
+
for m in (last, best) if best.exists() else (last): # speed, mAP tests
|
416 |
results, _, _ = test.test(opt.data,
|
417 |
batch_size=batch_size * 2,
|
418 |
imgsz=imgsz_test,
|
419 |
+
conf_thres=0.001,
|
420 |
+
iou_thres=0.7,
|
421 |
+
model=attempt_load(m, device).half(),
|
422 |
single_cls=opt.single_cls,
|
423 |
dataloader=testloader,
|
424 |
save_dir=save_dir,
|
425 |
+
save_json=True,
|
426 |
plots=False)
|
427 |
|
428 |
else:
|