glenn-jocher
commited on
Commit
•
140d84c
1
Parent(s):
ea34f84
comment updates
Browse files- train.py +2 -2
- utils/utils.py +2 -5
train.py
CHANGED
@@ -152,13 +152,13 @@ def train(hyp):
|
|
152 |
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
|
153 |
|
154 |
# Distributed training
|
155 |
-
if device.type != 'cpu' and torch.cuda.device_count() > 1 and
|
156 |
dist.init_process_group(backend='nccl', # distributed backend
|
157 |
init_method='tcp://127.0.0.1:9999', # init method
|
158 |
world_size=1, # number of nodes
|
159 |
rank=0) # node rank
|
|
|
160 |
model = torch.nn.parallel.DistributedDataParallel(model)
|
161 |
-
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
|
162 |
|
163 |
# Trainloader
|
164 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
|
|
152 |
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
|
153 |
|
154 |
# Distributed training
|
155 |
+
if device.type != 'cpu' and torch.cuda.device_count() > 1 and dist.is_available():
|
156 |
dist.init_process_group(backend='nccl', # distributed backend
|
157 |
init_method='tcp://127.0.0.1:9999', # init method
|
158 |
world_size=1, # number of nodes
|
159 |
rank=0) # node rank
|
160 |
+
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) # requires world_size > 1
|
161 |
model = torch.nn.parallel.DistributedDataParallel(model)
|
|
|
162 |
|
163 |
# Trainloader
|
164 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
|
utils/utils.py
CHANGED
@@ -503,6 +503,7 @@ def build_targets(p, targets, model):
|
|
503 |
off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
|
504 |
at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
|
505 |
|
|
|
506 |
style = 'rect4'
|
507 |
for i in range(det.nl):
|
508 |
anchors = det.anchors[i]
|
@@ -517,7 +518,6 @@ def build_targets(p, targets, model):
|
|
517 |
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
|
518 |
|
519 |
# overlaps
|
520 |
-
g = 0.5 # offset
|
521 |
gxy = t[:, 2:4] # grid xy
|
522 |
z = torch.zeros_like(gxy)
|
523 |
if style == 'rect2':
|
@@ -878,10 +878,7 @@ def fitness(x):
|
|
878 |
|
879 |
|
880 |
def output_to_target(output, width, height):
|
881 |
-
|
882 |
-
Convert a YOLO model output to target format
|
883 |
-
[batch_id, class_id, x, y, w, h, conf]
|
884 |
-
"""
|
885 |
if isinstance(output, torch.Tensor):
|
886 |
output = output.cpu().numpy()
|
887 |
|
|
|
503 |
off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
|
504 |
at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
|
505 |
|
506 |
+
g = 0.5 # offset
|
507 |
style = 'rect4'
|
508 |
for i in range(det.nl):
|
509 |
anchors = det.anchors[i]
|
|
|
518 |
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
|
519 |
|
520 |
# overlaps
|
|
|
521 |
gxy = t[:, 2:4] # grid xy
|
522 |
z = torch.zeros_like(gxy)
|
523 |
if style == 'rect2':
|
|
|
878 |
|
879 |
|
880 |
def output_to_target(output, width, height):
|
881 |
+
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
|
|
|
|
|
|
882 |
if isinstance(output, torch.Tensor):
|
883 |
output = output.cpu().numpy()
|
884 |
|