glenn-jocher commited on
Commit
f000714
β€’
1 Parent(s): 3356f26

Refactor collections and fstrings (#7821)

Browse files

* Update torch_utils.py

* Additional code refactoring

* tuples to sets

* Cleanup

detect.py CHANGED
@@ -160,15 +160,15 @@ def run(
160
  if save_txt: # Write to file
161
  xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
162
  line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
163
- with open(txt_path + '.txt', 'a') as f:
164
  f.write(('%g ' * len(line)).rstrip() % line + '\n')
165
 
166
  if save_img or save_crop or view_img: # Add bbox to image
167
  c = int(cls) # integer class
168
  label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
169
  annotator.box_label(xyxy, label, color=colors(c, True))
170
- if save_crop:
171
- save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
172
 
173
  # Stream results
174
  im0 = annotator.result()
 
160
  if save_txt: # Write to file
161
  xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
162
  line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
163
+ with open(f'{txt_path}.txt', 'a') as f:
164
  f.write(('%g ' * len(line)).rstrip() % line + '\n')
165
 
166
  if save_img or save_crop or view_img: # Add bbox to image
167
  c = int(cls) # integer class
168
  label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
169
  annotator.box_label(xyxy, label, color=colors(c, True))
170
+ if save_crop:
171
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
172
 
173
  # Stream results
174
  im0 = annotator.result()
export.py CHANGED
@@ -175,7 +175,7 @@ def export_openvino(model, im, file, half, prefix=colorstr('OpenVINO:')):
175
  import openvino.inference_engine as ie
176
 
177
  LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
178
- f = str(file).replace('.pt', '_openvino_model' + os.sep)
179
 
180
  cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
181
  subprocess.check_output(cmd, shell=True)
@@ -385,7 +385,7 @@ def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
385
  cmd = 'edgetpu_compiler --version'
386
  help_url = 'https://coral.ai/docs/edgetpu/compiler/'
387
  assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
388
- if subprocess.run(cmd + ' >/dev/null', shell=True).returncode != 0:
389
  LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
390
  sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
391
  for c in (
@@ -419,7 +419,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
419
  LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
420
  f = str(file).replace('.pt', '_web_model') # js dir
421
  f_pb = file.with_suffix('.pb') # *.pb path
422
- f_json = f + '/model.json' # *.json path
423
 
424
  cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
425
  f'--output_node_names="Identity,Identity_1,Identity_2,Identity_3" {f_pb} {f}'
 
175
  import openvino.inference_engine as ie
176
 
177
  LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
178
+ f = str(file).replace('.pt', f'_openvino_model{os.sep}')
179
 
180
  cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
181
  subprocess.check_output(cmd, shell=True)
 
385
  cmd = 'edgetpu_compiler --version'
386
  help_url = 'https://coral.ai/docs/edgetpu/compiler/'
387
  assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
388
+ if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
389
  LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
390
  sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
391
  for c in (
 
419
  LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
420
  f = str(file).replace('.pt', '_web_model') # js dir
421
  f_pb = file.with_suffix('.pb') # *.pb path
422
+ f_json = f'{f}/model.json' # *.json path
423
 
424
  cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
425
  f'--output_node_names="Identity,Identity_1,Identity_2,Identity_3" {f_pb} {f}'
train.py CHANGED
@@ -88,7 +88,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
88
 
89
  # Loggers
90
  data_dict = None
91
- if RANK in [-1, 0]:
92
  loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
93
  if loggers.wandb:
94
  data_dict = loggers.wandb.data_dict
@@ -181,7 +181,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
181
  scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
182
 
183
  # EMA
184
- ema = ModelEMA(model) if RANK in [-1, 0] else None
185
 
186
  # Resume
187
  start_epoch, best_fitness = 0, 0.0
@@ -238,7 +238,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
238
  assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
239
 
240
  # Process 0
241
- if RANK in [-1, 0]:
242
  val_loader = create_dataloader(val_path,
243
  imgsz,
244
  batch_size // WORLD_SIZE * 2,
@@ -320,7 +320,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
320
  train_loader.sampler.set_epoch(epoch)
321
  pbar = enumerate(train_loader)
322
  LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
323
- if RANK in (-1, 0):
324
  pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
325
  optimizer.zero_grad()
326
  for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
@@ -369,7 +369,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
369
  last_opt_step = ni
370
 
371
  # Log
372
- if RANK in (-1, 0):
373
  mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
374
  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
375
  pbar.set_description(('%10s' * 2 + '%10.4g' * 5) %
@@ -383,7 +383,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
383
  lr = [x['lr'] for x in optimizer.param_groups] # for loggers
384
  scheduler.step()
385
 
386
- if RANK in (-1, 0):
387
  # mAP
388
  callbacks.run('on_train_epoch_end', epoch=epoch)
389
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
@@ -444,7 +444,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
444
 
445
  # end epoch ----------------------------------------------------------------------------------------------------
446
  # end training -----------------------------------------------------------------------------------------------------
447
- if RANK in (-1, 0):
448
  LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
449
  for f in last, best:
450
  if f.exists():
@@ -522,7 +522,7 @@ def parse_opt(known=False):
522
 
523
  def main(opt, callbacks=Callbacks()):
524
  # Checks
525
- if RANK in (-1, 0):
526
  print_args(vars(opt))
527
  check_git_status()
528
  check_requirements(exclude=['thop'])
 
88
 
89
  # Loggers
90
  data_dict = None
91
+ if RANK in {-1, 0}:
92
  loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
93
  if loggers.wandb:
94
  data_dict = loggers.wandb.data_dict
 
181
  scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
182
 
183
  # EMA
184
+ ema = ModelEMA(model) if RANK in {-1, 0} else None
185
 
186
  # Resume
187
  start_epoch, best_fitness = 0, 0.0
 
238
  assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
239
 
240
  # Process 0
241
+ if RANK in {-1, 0}:
242
  val_loader = create_dataloader(val_path,
243
  imgsz,
244
  batch_size // WORLD_SIZE * 2,
 
320
  train_loader.sampler.set_epoch(epoch)
321
  pbar = enumerate(train_loader)
322
  LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
323
+ if RANK in {-1, 0}:
324
  pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
325
  optimizer.zero_grad()
326
  for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
 
369
  last_opt_step = ni
370
 
371
  # Log
372
+ if RANK in {-1, 0}:
373
  mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
374
  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
375
  pbar.set_description(('%10s' * 2 + '%10.4g' * 5) %
 
383
  lr = [x['lr'] for x in optimizer.param_groups] # for loggers
384
  scheduler.step()
385
 
386
+ if RANK in {-1, 0}:
387
  # mAP
388
  callbacks.run('on_train_epoch_end', epoch=epoch)
389
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
 
444
 
445
  # end epoch ----------------------------------------------------------------------------------------------------
446
  # end training -----------------------------------------------------------------------------------------------------
447
+ if RANK in {-1, 0}:
448
  LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
449
  for f in last, best:
450
  if f.exists():
 
522
 
523
  def main(opt, callbacks=Callbacks()):
524
  # Checks
525
+ if RANK in {-1, 0}:
526
  print_args(vars(opt))
527
  check_git_status()
528
  check_requirements(exclude=['thop'])
utils/autoanchor.py CHANGED
@@ -104,7 +104,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
104
  s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
105
  f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
106
  f'past_thr={x[x > thr].mean():.3f}-mean: '
107
- for i, x in enumerate(k):
108
  s += '%i,%i, ' % (round(x[0]), round(x[1]))
109
  if verbose:
110
  LOGGER.info(s[:-2])
 
104
  s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
105
  f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
106
  f'past_thr={x[x > thr].mean():.3f}-mean: '
107
+ for x in k:
108
  s += '%i,%i, ' % (round(x[0]), round(x[1]))
109
  if verbose:
110
  LOGGER.info(s[:-2])
utils/dataloaders.py CHANGED
@@ -57,9 +57,7 @@ def exif_size(img):
57
  s = img.size # (width, height)
58
  try:
59
  rotation = dict(img._getexif().items())[orientation]
60
- if rotation == 6: # rotation 270
61
- s = (s[1], s[0])
62
- elif rotation == 8: # rotation 90
63
  s = (s[1], s[0])
64
  except Exception:
65
  pass
@@ -156,7 +154,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
156
  return len(self.batch_sampler.sampler)
157
 
158
  def __iter__(self):
159
- for i in range(len(self)):
160
  yield next(self.iterator)
161
 
162
 
@@ -224,10 +222,9 @@ class LoadImages:
224
  self.cap.release()
225
  if self.count == self.nf: # last video
226
  raise StopIteration
227
- else:
228
- path = self.files[self.count]
229
- self.new_video(path)
230
- ret_val, img0 = self.cap.read()
231
 
232
  self.frame += 1
233
  s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
@@ -390,7 +387,7 @@ class LoadStreams:
390
 
391
  def img2label_paths(img_paths):
392
  # Define label paths as a function of image paths
393
- sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
394
  return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
395
 
396
 
@@ -456,7 +453,7 @@ class LoadImagesAndLabels(Dataset):
456
 
457
  # Display cache
458
  nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
459
- if exists and LOCAL_RANK in (-1, 0):
460
  d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
461
  tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
462
  if cache['msgs']:
 
57
  s = img.size # (width, height)
58
  try:
59
  rotation = dict(img._getexif().items())[orientation]
60
+ if rotation in [6, 8]: # rotation 270 or 90
 
 
61
  s = (s[1], s[0])
62
  except Exception:
63
  pass
 
154
  return len(self.batch_sampler.sampler)
155
 
156
  def __iter__(self):
157
+ for _ in range(len(self)):
158
  yield next(self.iterator)
159
 
160
 
 
222
  self.cap.release()
223
  if self.count == self.nf: # last video
224
  raise StopIteration
225
+ path = self.files[self.count]
226
+ self.new_video(path)
227
+ ret_val, img0 = self.cap.read()
 
228
 
229
  self.frame += 1
230
  s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
 
387
 
388
  def img2label_paths(img_paths):
389
  # Define label paths as a function of image paths
390
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
391
  return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
392
 
393
 
 
453
 
454
  # Display cache
455
  nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
456
+ if exists and LOCAL_RANK in {-1, 0}:
457
  d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
458
  tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
459
  if cache['msgs']:
utils/general.py CHANGED
@@ -84,7 +84,7 @@ def set_logging(name=None, verbose=VERBOSE):
84
  for h in logging.root.handlers:
85
  logging.root.removeHandler(h) # remove all handlers associated with the root logger object
86
  rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
87
- level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
88
  log = logging.getLogger(name)
89
  log.setLevel(level)
90
  handler = logging.StreamHandler()
 
84
  for h in logging.root.handlers:
85
  logging.root.removeHandler(h) # remove all handlers associated with the root logger object
86
  rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
87
+ level = logging.INFO if verbose and rank in {-1, 0} else logging.WARNING
88
  log = logging.getLogger(name)
89
  log.setLevel(level)
90
  handler = logging.StreamHandler()
utils/loggers/__init__.py CHANGED
@@ -22,7 +22,7 @@ try:
22
  import wandb
23
 
24
  assert hasattr(wandb, '__version__') # verify package import not local dir
25
- if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in [0, -1]:
26
  try:
27
  wandb_login_success = wandb.login(timeout=30)
28
  except wandb.errors.UsageError: # known non-TTY terminal issue
@@ -176,7 +176,7 @@ class Loggers():
176
  if not self.opt.evolve:
177
  wandb.log_artifact(str(best if best.exists() else last),
178
  type='model',
179
- name='run_' + self.wandb.wandb_run.id + '_model',
180
  aliases=['latest', 'best', 'stripped'])
181
  self.wandb.finish_run()
182
 
 
22
  import wandb
23
 
24
  assert hasattr(wandb, '__version__') # verify package import not local dir
25
+ if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in {0, -1}:
26
  try:
27
  wandb_login_success = wandb.login(timeout=30)
28
  except wandb.errors.UsageError: # known non-TTY terminal issue
 
176
  if not self.opt.evolve:
177
  wandb.log_artifact(str(best if best.exists() else last),
178
  type='model',
179
+ name=f'run_{self.wandb.wandb_run.id}_model',
180
  aliases=['latest', 'best', 'stripped'])
181
  self.wandb.finish_run()
182
 
utils/metrics.py CHANGED
@@ -55,32 +55,31 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
55
  i = pred_cls == c
56
  n_l = nt[ci] # number of labels
57
  n_p = i.sum() # number of predictions
58
-
59
  if n_p == 0 or n_l == 0:
60
  continue
61
- else:
62
- # Accumulate FPs and TPs
63
- fpc = (1 - tp[i]).cumsum(0)
64
- tpc = tp[i].cumsum(0)
65
 
66
- # Recall
67
- recall = tpc / (n_l + eps) # recall curve
68
- r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
 
 
 
 
69
 
70
- # Precision
71
- precision = tpc / (tpc + fpc) # precision curve
72
- p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
73
 
74
- # AP from recall-precision curve
75
- for j in range(tp.shape[1]):
76
- ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
77
- if plot and j == 0:
78
- py.append(np.interp(px, mrec, mpre)) # precision at [email protected]
79
 
80
  # Compute F1 (harmonic mean of precision and recall)
81
  f1 = 2 * p * r / (p + r + eps)
82
  names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
83
- names = {i: v for i, v in enumerate(names)} # to dict
84
  if plot:
85
  plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
86
  plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
@@ -314,7 +313,7 @@ def wh_iou(wh1, wh2):
314
  # Plots ----------------------------------------------------------------------------------------------------------------
315
 
316
 
317
- def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
318
  # Precision-recall curve
319
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
320
  py = np.stack(py, axis=1)
@@ -331,11 +330,11 @@ def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
331
  ax.set_xlim(0, 1)
332
  ax.set_ylim(0, 1)
333
  plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
334
- fig.savefig(Path(save_dir), dpi=250)
335
  plt.close()
336
 
337
 
338
- def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
339
  # Metric-confidence curve
340
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
341
 
@@ -352,5 +351,5 @@ def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence'
352
  ax.set_xlim(0, 1)
353
  ax.set_ylim(0, 1)
354
  plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
355
- fig.savefig(Path(save_dir), dpi=250)
356
  plt.close()
 
55
  i = pred_cls == c
56
  n_l = nt[ci] # number of labels
57
  n_p = i.sum() # number of predictions
 
58
  if n_p == 0 or n_l == 0:
59
  continue
 
 
 
 
60
 
61
+ # Accumulate FPs and TPs
62
+ fpc = (1 - tp[i]).cumsum(0)
63
+ tpc = tp[i].cumsum(0)
64
+
65
+ # Recall
66
+ recall = tpc / (n_l + eps) # recall curve
67
+ r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
68
 
69
+ # Precision
70
+ precision = tpc / (tpc + fpc) # precision curve
71
+ p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
72
 
73
+ # AP from recall-precision curve
74
+ for j in range(tp.shape[1]):
75
+ ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
76
+ if plot and j == 0:
77
+ py.append(np.interp(px, mrec, mpre)) # precision at [email protected]
78
 
79
  # Compute F1 (harmonic mean of precision and recall)
80
  f1 = 2 * p * r / (p + r + eps)
81
  names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
82
+ names = dict(enumerate(names)) # to dict
83
  if plot:
84
  plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
85
  plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
 
313
  # Plots ----------------------------------------------------------------------------------------------------------------
314
 
315
 
316
+ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
317
  # Precision-recall curve
318
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
319
  py = np.stack(py, axis=1)
 
330
  ax.set_xlim(0, 1)
331
  ax.set_ylim(0, 1)
332
  plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
333
+ fig.savefig(save_dir, dpi=250)
334
  plt.close()
335
 
336
 
337
+ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
338
  # Metric-confidence curve
339
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
340
 
 
351
  ax.set_xlim(0, 1)
352
  ax.set_ylim(0, 1)
353
  plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
354
+ fig.savefig(save_dir, dpi=250)
355
  plt.close()
utils/torch_utils.py CHANGED
@@ -50,9 +50,9 @@ def device_count():
50
 
51
 
52
  def select_device(device='', batch_size=0, newline=True):
53
- # device = 'cpu' or '0' or '0,1,2,3'
54
  s = f'YOLOv5 πŸš€ {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
55
- device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
56
  cpu = device == 'cpu'
57
  if cpu:
58
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
@@ -97,7 +97,8 @@ def profile(input, ops, n=10, device=None):
97
  # profile(input, [m1, m2], n=100) # profile over 100 iterations
98
 
99
  results = []
100
- device = device or select_device()
 
101
  print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
102
  f"{'input':>24s}{'output':>24s}")
103
 
@@ -127,9 +128,8 @@ def profile(input, ops, n=10, device=None):
127
  tf += (t[1] - t[0]) * 1000 / n # ms per op forward
128
  tb += (t[2] - t[1]) * 1000 / n # ms per op backward
129
  mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
130
- s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
131
- s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
132
- p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
133
  print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
134
  results.append([p, flops, mem, tf, tb, s_in, s_out])
135
  except Exception as e:
@@ -227,7 +227,7 @@ def model_info(model, verbose=False, img_size=640):
227
  flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
228
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
229
  fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
230
- except (ImportError, Exception):
231
  fs = ''
232
 
233
  name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
@@ -238,13 +238,12 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
238
  # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
239
  if ratio == 1.0:
240
  return img
241
- else:
242
- h, w = img.shape[2:]
243
- s = (int(h * ratio), int(w * ratio)) # new size
244
- img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
245
- if not same_shape: # pad/crop img
246
- h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
247
- return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
248
 
249
 
250
  def copy_attr(a, b, include=(), exclude=()):
 
50
 
51
 
52
  def select_device(device='', batch_size=0, newline=True):
53
+ # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
54
  s = f'YOLOv5 πŸš€ {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
55
+ device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
56
  cpu = device == 'cpu'
57
  if cpu:
58
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
 
97
  # profile(input, [m1, m2], n=100) # profile over 100 iterations
98
 
99
  results = []
100
+ if not isinstance(device, torch.device):
101
+ device = select_device(device)
102
  print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
103
  f"{'input':>24s}{'output':>24s}")
104
 
 
128
  tf += (t[1] - t[0]) * 1000 / n # ms per op forward
129
  tb += (t[2] - t[1]) * 1000 / n # ms per op backward
130
  mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
131
+ s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
132
+ p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
 
133
  print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
134
  results.append([p, flops, mem, tf, tb, s_in, s_out])
135
  except Exception as e:
 
227
  flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
228
  img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
229
  fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
230
+ except Exception:
231
  fs = ''
232
 
233
  name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
 
238
  # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
239
  if ratio == 1.0:
240
  return img
241
+ h, w = img.shape[2:]
242
+ s = (int(h * ratio), int(w * ratio)) # new size
243
+ img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
244
+ if not same_shape: # pad/crop img
245
+ h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
246
+ return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
 
247
 
248
 
249
  def copy_attr(a, b, include=(), exclude=()):