Nan qiningonline committed on
Commit
4a6dfff
1 Parent(s): 97b6b14

Pass `LOCAL_RANK` to `torch_distributed_zero_first()` (#5114)

Browse files
Files changed (1) hide show
  1. train.py +3 -3
train.py CHANGED
@@ -99,7 +99,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
99
  plots = not evolve # create plots
100
  cuda = device.type != 'cpu'
101
  init_seeds(1 + RANK)
102
- with torch_distributed_zero_first(RANK):
103
  data_dict = data_dict or check_dataset(data) # check if None
104
  train_path, val_path = data_dict['train'], data_dict['val']
105
  nc = 1 if single_cls else int(data_dict['nc']) # number of classes
@@ -111,7 +111,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
111
  check_suffix(weights, '.pt') # check weights
112
  pretrained = weights.endswith('.pt')
113
  if pretrained:
114
- with torch_distributed_zero_first(RANK):
115
  weights = attempt_download(weights) # download if not found locally
116
  ckpt = torch.load(weights, map_location=device) # load checkpoint
117
  model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
@@ -208,7 +208,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
208
 
209
  # Trainloader
210
  train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
211
- hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
212
  workers=workers, image_weights=opt.image_weights, quad=opt.quad,
213
  prefix=colorstr('train: '))
214
  mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
 
99
  plots = not evolve # create plots
100
  cuda = device.type != 'cpu'
101
  init_seeds(1 + RANK)
102
+ with torch_distributed_zero_first(LOCAL_RANK):
103
  data_dict = data_dict or check_dataset(data) # check if None
104
  train_path, val_path = data_dict['train'], data_dict['val']
105
  nc = 1 if single_cls else int(data_dict['nc']) # number of classes
 
111
  check_suffix(weights, '.pt') # check weights
112
  pretrained = weights.endswith('.pt')
113
  if pretrained:
114
+ with torch_distributed_zero_first(LOCAL_RANK):
115
  weights = attempt_download(weights) # download if not found locally
116
  ckpt = torch.load(weights, map_location=device) # load checkpoint
117
  model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
 
208
 
209
  # Trainloader
210
  train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
211
+ hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
212
  workers=workers, image_weights=opt.image_weights, quad=opt.quad,
213
  prefix=colorstr('train: '))
214
  mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class