glenn-jocher committed
Commit fab5085
1 Parent(s): fd96810

EMA bug fix 2 (#2330)


* EMA bug fix 2

* update

Files changed (4)
  1. hubconf.py +1 -1
  2. models/experimental.py +2 -1
  3. train.py +5 -5
  4. utils/general.py +5 -3
hubconf.py CHANGED
@@ -120,7 +120,7 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
     """
     model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
     if isinstance(model, dict):
-        model = model['model']  # load model
+        model = model['ema' if model.get('ema') else 'model']  # load model
 
     hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
     hub_model.load_state_dict(model.float().state_dict())  # load state_dict
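A minimal sketch of the key-selection rule this hunk introduces, using a hypothetical checkpoint dict (the variable names here are illustrative, not from the commit): a checkpoint that carries a populated 'ema' entry is loaded from it, with 'model' as the fallback.

    import torch.nn as nn

    # Hypothetical checkpoint layout after this fix: raw and EMA weights stored under separate keys.
    raw_model, ema_model = nn.Linear(2, 2), nn.Linear(2, 2)
    ckpt = {'model': raw_model, 'ema': ema_model, 'updates': 1000}

    # Same selection rule as the hubconf.py change: prefer EMA weights when present and non-None.
    model = ckpt['ema' if ckpt.get('ema') else 'model']
    assert model is ema_model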
models/experimental.py CHANGED
@@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None):
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         attempt_download(w)
-        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+        ckpt = torch.load(w, map_location=map_location)  # load
+        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
 
     # Compatibility updates
     for m in model.modules():
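The same preference now applies when weights are loaded for inference or ensembling: any checkpoint whose 'ema' entry is populated is deployed with its EMA weights. A usage sketch of attempt_load() (paths are illustrative):

    from models.experimental import attempt_load

    # Illustrative paths; map_location follows the usual torch.load semantics.
    model = attempt_load('runs/train/exp/weights/best.pt', map_location='cpu')   # single model
    ensemble = attempt_load(['yolov5s.pt', 'yolov5m.pt'], map_location='cpu')    # Ensemble of models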
train.py CHANGED
@@ -151,8 +151,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
         # EMA
         if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
-            ema.updates = ckpt['ema'][1]
+            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+            ema.updates = ckpt['updates']
 
         # Results
         if ckpt.get('training_results') is not None:
@@ -383,9 +383,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema if final_epoch else deepcopy(
-                            model.module if is_parallel(model) else model).half(),
-                        'ema': (deepcopy(ema.ema).half(), ema.updates),
+                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                        'ema': deepcopy(ema.ema).half(),
+                        'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
 
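Net effect of the two train.py hunks: 'model' always holds the raw weights, while the EMA weights and their update counter move from a (weights, updates) tuple under 'ema' to two flat keys, 'ema' and 'updates'. A small sketch with stand-in modules (not the commit's code) showing that a checkpoint built with the new layout resumes exactly as the loading hunk expects:

    from copy import deepcopy
    from torch import nn

    # Stand-ins for the training objects; only the key layout matters here.
    model, ema_model, updates = nn.Linear(2, 2), nn.Linear(2, 2), 1234

    ckpt = {'model': deepcopy(model).half(),    # raw weights
            'ema': deepcopy(ema_model).half(),  # EMA weights as their own key (no longer a tuple)
            'updates': updates}                 # EMA update counter as its own key

    # Resuming mirrors the loading hunk above: state_dict from 'ema', counter from 'updates'.
    ema_model.load_state_dict(ckpt['ema'].float().state_dict())
    restored_updates = ckpt['updates']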
 
utils/general.py CHANGED
@@ -481,10 +481,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
     return output
 
 
-def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
+def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+    if x.get('ema'):
+        x['model'] = x['ema']  # replace model with ema
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
         x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
@@ -492,7 +494,7 @@ def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *;
         p.requires_grad = False
     torch.save(x, s or f)
     mb = os.path.getsize(s or f) / 1E6  # filesize
-    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
+    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
 
 
 def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
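With the checkpoint now carrying both sets of weights, strip_optimizer() promotes the EMA weights to 'model' before nulling the training-only keys, so the file kept for inference contains the EMA model. A usage sketch with illustrative paths:

    from utils.general import strip_optimizer

    # After this commit the stripped file keeps the EMA weights as 'model' and drops
    # 'optimizer', 'training_results', 'wandb_id', 'ema' and 'updates'.
    strip_optimizer('runs/train/exp/weights/last.pt')                      # strip in place
    strip_optimizer('runs/train/exp/weights/last.pt', 'last_stripped.pt')  # or save a stripped copy as 's'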