bilzard
commited on
Commit
•
e1dc894
1
Parent(s):
d95978a
Enable AdamW optimizer (#6152)
Browse files
train.py
CHANGED
@@ -22,7 +22,7 @@ import torch.nn as nn
|
|
22 |
import yaml
|
23 |
from torch.cuda import amp
|
24 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
25 |
-
from torch.optim import SGD, Adam, lr_scheduler
|
26 |
from tqdm import tqdm
|
27 |
|
28 |
FILE = Path(__file__).resolve()
|
@@ -155,8 +155,10 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
155 |
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
|
156 |
g1.append(v.weight)
|
157 |
|
158 |
-
if opt.
|
159 |
optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
|
|
|
|
160 |
else:
|
161 |
optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
162 |
|
@@ -460,7 +462,7 @@ def parse_opt(known=False):
|
|
460 |
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
461 |
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
462 |
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
|
463 |
-
parser.add_argument('--
|
464 |
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
465 |
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
|
466 |
parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
|
|
|
22 |
import yaml
|
23 |
from torch.cuda import amp
|
24 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
25 |
+
from torch.optim import SGD, Adam, AdamW, lr_scheduler
|
26 |
from tqdm import tqdm
|
27 |
|
28 |
FILE = Path(__file__).resolve()
|
|
|
155 |
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
|
156 |
g1.append(v.weight)
|
157 |
|
158 |
+
if opt.optimizer == 'Adam':
|
159 |
optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
160 |
+
elif opt.optimizer == 'AdamW':
|
161 |
+
optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
162 |
else:
|
163 |
optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
164 |
|
|
|
462 |
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
463 |
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
464 |
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
|
465 |
+
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
|
466 |
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
467 |
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
|
468 |
parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
|