Training reproducibility improvements (#8213)
Browse files
* attempt at reproducibility
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use deterministic algs
* fix everything :)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* revert dataloader changes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* process_batch as np
* remove newline
* Remove dataloader init fcn
* Update val.py
* Update train.py
* revert additional changes
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update train.py
* Add --seed arg
* Update general.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update train.py
* Update train.py
* Update val.py
* Update train.py
* Update general.py
* Update general.py
* Add deterministic argument to init_seeds()
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
- train.py +2 -1
- utils/general.py +9 -1
@@ -101,7 +101,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
101 |
# Config
|
102 |
plots = not evolve and not opt.noplots # create plots
|
103 |
cuda = device.type != 'cpu'
|
104 |
-
init_seeds(1 + RANK)
|
105 |
with torch_distributed_zero_first(LOCAL_RANK):
|
106 |
data_dict = data_dict or check_dataset(data) # check if None
|
107 |
train_path, val_path = data_dict['train'], data_dict['val']
|
@@ -504,6 +504,7 @@ def parse_opt(known=False):
|
|
504 |
parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
|
505 |
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
|
506 |
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
|
|
|
507 |
parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
|
508 |
|
509 |
# Weights & Biases arguments
|
|
|
101 |
# Config
|
102 |
plots = not evolve and not opt.noplots # create plots
|
103 |
cuda = device.type != 'cpu'
|
104 |
+
init_seeds(opt.seed + 1 + RANK, deterministic=True)
|
105 |
with torch_distributed_zero_first(LOCAL_RANK):
|
106 |
data_dict = data_dict or check_dataset(data) # check if None
|
107 |
train_path, val_path = data_dict['train'], data_dict['val']
|
|
|
504 |
parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
|
505 |
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
|
506 |
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
|
507 |
+
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
|
508 |
parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
|
509 |
|
510 |
# Weights & Biases arguments
|
@@ -195,14 +195,22 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
|
|
195 |
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
196 |
|
197 |
|
198 |
-
def init_seeds(seed=0):
|
199 |
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
200 |
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
|
201 |
import torch.backends.cudnn as cudnn
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
random.seed(seed)
|
203 |
np.random.seed(seed)
|
204 |
torch.manual_seed(seed)
|
205 |
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
|
|
|
|
|
206 |
|
207 |
|
208 |
def intersect_dicts(da, db, exclude=()):
|
|
|
195 |
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
196 |
|
197 |
|
198 |
+
def init_seeds(seed=0, deterministic=False):
    """Initialize random number generator (RNG) seeds.

    Seeds Python's ``random``, NumPy and PyTorch with ``seed`` and configures
    the cuDNN benchmark/deterministic flags.
    See https://pytorch.org/docs/stable/notes/randomness.html

    Args:
        seed: Integer seed passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``.
        deterministic: If true and ``check_version(torch.__version__, '1.12.0')``
            passes, additionally enable ``torch.use_deterministic_algorithms``
            and pin cuDNN to its deterministic, non-benchmarking mode
            (https://github.com/ultralytics/yolov5/pull/8213).
    """
    import torch.backends.cudnn as cudnn

    # Short-circuits exactly like the original condition: check_version is only
    # consulted when the caller actually asked for deterministic behaviour.
    enforce_determinism = deterministic and check_version(torch.__version__, '1.12.0')
    if enforce_determinism:
        torch.use_deterministic_algorithms(True)
        # Deterministic cuBLAS requires a fixed workspace configuration.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the generators on all devices, CUDA included

    # Callers may pass a non-zero seed together with deterministic=True (e.g.
    # opt.seed + 1 + RANK). Selecting the cuDNN mode from the seed value alone
    # would then re-enable benchmark autotuning, which may pick different
    # kernels from run to run, so an explicit deterministic request wins.
    if enforce_determinism or seed == 0:
        # slower and more reproducible
        cudnn.benchmark, cudnn.deterministic = False, True
    else:
        # faster and less reproducible
        cudnn.benchmark, cudnn.deterministic = True, False
|
214 |
|
215 |
|
216 |
def intersect_dicts(da, db, exclude=()):
|