glenn-jocher
committed on
Commit fab5085
Parent(s): fd96810
EMA bug fix 2 (#2330)
* EMA bug fix 2
* update
- hubconf.py +1 -1
- models/experimental.py +2 -1
- train.py +5 -5
- utils/general.py +5 -3
hubconf.py
CHANGED
@@ -120,7 +120,7 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
     """
     model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
     if isinstance(model, dict):
-        model = model['model']  # load model
+        model = model['ema' if model.get('ema') else 'model']  # load model

     hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
     hub_model.load_state_dict(model.float().state_dict())  # load state_dict
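Note: after this change, checkpoints carry both a 'model' and an 'ema' entry, and loaders prefer the EMA weights. A minimal sketch of the same key-selection pattern for loading a checkpoint by hand (the weights path is a placeholder, not from this commit):

import torch

# Prefer the EMA weights when present; dict.get() falls back cleanly
# to 'model' for checkpoints saved before this change.
ckpt = torch.load('yolov5s.pt', map_location='cpu')  # placeholder path
model = ckpt['ema' if ckpt.get('ema') else 'model']
model = model.float().eval()  # checkpoints are FP16; cast to FP32 for inference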
models/experimental.py
CHANGED
@@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None):
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         attempt_download(w)
-        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+        ckpt = torch.load(w, map_location=map_location)  # load
+        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

     # Compatibility updates
     for m in model.modules():
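Note: a hedged usage sketch of attempt_load as modified above (the weight filenames are placeholders); passing a list builds an Ensemble, and each member now fuses the EMA weights when the checkpoint provides them:

from models.experimental import attempt_load

model = attempt_load('yolov5s.pt', map_location='cpu')                     # single fused FP32 model
ensemble = attempt_load(['yolov5s.pt', 'yolov5m.pt'], map_location='cpu')  # Ensemble over several checkpoints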
train.py
CHANGED
@@ -151,8 +151,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

         # EMA
         if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
-            ema.updates = ckpt['ema'][1]
+            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+            ema.updates = ckpt['updates']

         # Results
         if ckpt.get('training_results') is not None:
@@ -383,9 +383,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema if final_epoch else deepcopy(
-                            model.module if is_parallel(model) else model).half(),
-                        'ema': (deepcopy(ema.ema).half(), ema.updates),
+                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                        'ema': deepcopy(ema.ema).half(),
+                        'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}

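Note: checkpointing ema.updates alongside the EMA weights matters because ModelEMA (utils/torch_utils.py) ramps its decay with the update count. A sketch of that ramp; the constants 0.9999 and 2000 are taken from the yolov5 ModelEMA of this era and should be treated as assumptions:

import math

# Effective EMA decay as a function of update count:
# decay(x) = base_decay * (1 - exp(-x / tau))
def ema_decay(updates, base_decay=0.9999, tau=2000):
    return base_decay * (1 - math.exp(-updates / tau))

print(ema_decay(0))      # 0.0    -> at step 0 the EMA copies the model exactly
print(ema_decay(2000))   # ~0.632 -> still warming up
print(ema_decay(20000))  # ~0.9999 -> long-run averaging

# If ckpt['updates'] is not restored on resume, the counter restarts at 0
# and the EMA decay re-warms from scratch instead of continuing the average.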
utils/general.py
CHANGED
@@ -481,10 +481,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None,
     return output


-def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
+def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for k in 'optimizer', 'training_results', 'wandb_id':  # keys
+    if x.get('ema'):
+        x['model'] = x['ema']  # replace model with ema
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
         x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
@@ -492,7 +494,7 @@ def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *;
         p.requires_grad = False
     torch.save(x, s or f)
     mb = os.path.getsize(s or f) / 1E6  # filesize
-    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
+    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


 def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
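Note: a usage sketch of the updated strip_optimizer (the paths are placeholders). With no second argument the checkpoint is overwritten in place; otherwise the stripped copy is written to s, as torch.save(x, s or f) above implies:

from utils.general import strip_optimizer

strip_optimizer('runs/train/exp/weights/best.pt')                   # strip in place
strip_optimizer('runs/train/exp/weights/best.pt', 'best_small.pt')  # write stripped copy to a new file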