Ayush Chaurasia pre-commit-ci[bot] glenn-jocher committed on
Commit
27d831b
1 Parent(s): 36f64a9

Training reproducibility improvements (#8213)

Browse files

* attempt at reproducibility

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use deterministic algs

* fix everything :)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert dataloader changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* process_batch as np

* remove newline

* Remove dataloader init fcn

* Update val.py

* Update train.py

* revert additional changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Add --seed arg

* Update general.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

* Update val.py

* Update train.py

* Update general.py

* Update general.py

* Add deterministic argument to init_seeds()

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>

Files changed (2) hide show
  1. train.py +2 -1
  2. utils/general.py +9 -1
train.py CHANGED
@@ -101,7 +101,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
101
  # Config
102
  plots = not evolve and not opt.noplots # create plots
103
  cuda = device.type != 'cpu'
104
- init_seeds(1 + RANK)
105
  with torch_distributed_zero_first(LOCAL_RANK):
106
  data_dict = data_dict or check_dataset(data) # check if None
107
  train_path, val_path = data_dict['train'], data_dict['val']
@@ -504,6 +504,7 @@ def parse_opt(known=False):
504
  parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
505
  parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
506
  parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
 
507
  parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
508
 
509
  # Weights & Biases arguments
 
101
  # Config
102
  plots = not evolve and not opt.noplots # create plots
103
  cuda = device.type != 'cpu'
104
+ init_seeds(opt.seed + 1 + RANK, deterministic=True)
105
  with torch_distributed_zero_first(LOCAL_RANK):
106
  data_dict = data_dict or check_dataset(data) # check if None
107
  train_path, val_path = data_dict['train'], data_dict['val']
 
504
  parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
505
  parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
506
  parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
507
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
508
  parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
509
 
510
  # Weights & Biases arguments
utils/general.py CHANGED
@@ -195,14 +195,22 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
195
  LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
196
 
197
 
198
- def init_seeds(seed=0):
199
  # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
200
  # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
201
  import torch.backends.cudnn as cudnn
 
 
 
 
 
 
202
  random.seed(seed)
203
  np.random.seed(seed)
204
  torch.manual_seed(seed)
205
  cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
 
 
206
 
207
 
208
  def intersect_dicts(da, db, exclude=()):
 
195
  LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
196
 
197
 
198
+ def init_seeds(seed=0, deterministic=False):
199
  # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
200
  # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
201
  import torch.backends.cudnn as cudnn
202
+
203
+ if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
204
+ torch.use_deterministic_algorithms(True)
205
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
206
+ # os.environ['PYTHONHASHSEED'] = str(seed)
207
+
208
  random.seed(seed)
209
  np.random.seed(seed)
210
  torch.manual_seed(seed)
211
  cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
212
+ # torch.cuda.manual_seed(seed)
213
+ # torch.cuda.manual_seed_all(seed) # for multi GPU, exception safe
214
 
215
 
216
  def intersect_dicts(da, db, exclude=()):