Nan
qiningonline
committed on
Commit
•
4a6dfff
1
Parent(s):
97b6b14
Pass `LOCAL_RANK` to `torch_distributed_zero_first()` (#5114)
Browse files
train.py
CHANGED
@@ -99,7 +99,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
99 |
plots = not evolve # create plots
|
100 |
cuda = device.type != 'cpu'
|
101 |
init_seeds(1 + RANK)
|
102 |
-
with torch_distributed_zero_first(RANK):
|
103 |
data_dict = data_dict or check_dataset(data) # check if None
|
104 |
train_path, val_path = data_dict['train'], data_dict['val']
|
105 |
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
@@ -111,7 +111,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
111 |
check_suffix(weights, '.pt') # check weights
|
112 |
pretrained = weights.endswith('.pt')
|
113 |
if pretrained:
|
114 |
-
with torch_distributed_zero_first(RANK):
|
115 |
weights = attempt_download(weights) # download if not found locally
|
116 |
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
117 |
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
@@ -208,7 +208,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
208 |
|
209 |
# Trainloader
|
210 |
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
211 |
-
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
|
212 |
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
|
213 |
prefix=colorstr('train: '))
|
214 |
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
|
|
|
99 |
plots = not evolve # create plots
|
100 |
cuda = device.type != 'cpu'
|
101 |
init_seeds(1 + RANK)
|
102 |
+
with torch_distributed_zero_first(LOCAL_RANK):
|
103 |
data_dict = data_dict or check_dataset(data) # check if None
|
104 |
train_path, val_path = data_dict['train'], data_dict['val']
|
105 |
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
|
|
111 |
check_suffix(weights, '.pt') # check weights
|
112 |
pretrained = weights.endswith('.pt')
|
113 |
if pretrained:
|
114 |
+
with torch_distributed_zero_first(LOCAL_RANK):
|
115 |
weights = attempt_download(weights) # download if not found locally
|
116 |
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
117 |
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
|
|
208 |
|
209 |
# Trainloader
|
210 |
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
211 |
+
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
|
212 |
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
|
213 |
prefix=colorstr('train: '))
|
214 |
mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
|