glenn-jocher pre-commit-ci[bot] commited on
Commit
c9042dc
1 Parent(s): d876caa

Improved non-latin `Annotator()` plotting (#7488)

Browse files

* Improved non-latin labels Annotator plotting

May resolve https://github.com/ultralytics/yolov5/issues/7460

* Update train.py

* Update train.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add progress arg

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (3) hide show
  1. train.py +5 -3
  2. utils/general.py +2 -2
  3. utils/plots.py +4 -3
train.py CHANGED
@@ -48,13 +48,13 @@ from utils.datasets import create_dataloader
48
  from utils.downloads import attempt_download
49
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50
  check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
51
- intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
52
- print_args, print_mutation, strip_optimizer)
53
  from utils.loggers import Loggers
54
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
55
  from utils.loss import ComputeLoss
56
  from utils.metrics import fitness
57
- from utils.plots import plot_evolve, plot_labels
58
  from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
59
 
60
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
@@ -105,6 +105,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
105
  init_seeds(1 + RANK)
106
  with torch_distributed_zero_first(LOCAL_RANK):
107
  data_dict = data_dict or check_dataset(data) # check if None
 
 
108
  train_path, val_path = data_dict['train'], data_dict['val']
109
  nc = 1 if single_cls else int(data_dict['nc']) # number of classes
110
  names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
 
48
  from utils.downloads import attempt_download
49
  from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50
  check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
51
+ intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods,
52
+ one_cycle, print_args, print_mutation, strip_optimizer)
53
  from utils.loggers import Loggers
54
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
55
  from utils.loss import ComputeLoss
56
  from utils.metrics import fitness
57
+ from utils.plots import check_font, plot_evolve, plot_labels
58
  from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
59
 
60
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
 
105
  init_seeds(1 + RANK)
106
  with torch_distributed_zero_first(LOCAL_RANK):
107
  data_dict = data_dict or check_dataset(data) # check if None
108
+ if not is_ascii(data_dict['names']): # non-latin labels, i.e. asian, arabic, cyrillic
109
+ check_font('Arial.Unicode.ttf', progress=True)
110
  train_path, val_path = data_dict['train'], data_dict['val']
111
  nc = 1 if single_cls else int(data_dict['nc']) # number of classes
112
  names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
utils/general.py CHANGED
@@ -424,13 +424,13 @@ def check_file(file, suffix=''):
424
  return files[0] # return file
425
 
426
 
427
- def check_font(font=FONT):
428
  # Download font to CONFIG_DIR if necessary
429
  font = Path(font)
430
  if not font.exists() and not (CONFIG_DIR / font.name).exists():
431
  url = "https://ultralytics.com/assets/" + font.name
432
  LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
433
- torch.hub.download_url_to_file(url, str(font), progress=False)
434
 
435
 
436
  def check_dataset(data, autodownload=True):
 
424
  return files[0] # return file
425
 
426
 
427
+ def check_font(font=FONT, progress=False):
428
  # Download font to CONFIG_DIR if necessary
429
  font = Path(font)
430
  if not font.exists() and not (CONFIG_DIR / font.name).exists():
431
  url = "https://ultralytics.com/assets/" + font.name
432
  LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
433
+ torch.hub.download_url_to_file(url, str(font), progress=progress)
434
 
435
 
436
  def check_dataset(data, autodownload=True):
utils/plots.py CHANGED
@@ -19,7 +19,7 @@ import torch
19
  from PIL import Image, ImageDraw, ImageFont
20
 
21
  from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
22
- increment_path, is_ascii, is_chinese, try_except, xywh2xyxy, xyxy2xywh)
23
  from utils.metrics import fitness
24
 
25
  # Settings
@@ -72,11 +72,12 @@ class Annotator:
72
  # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
73
  def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
74
  assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
75
- self.pil = pil or not is_ascii(example) or is_chinese(example)
 
76
  if self.pil: # use PIL
77
  self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
78
  self.draw = ImageDraw.Draw(self.im)
79
- self.font = check_pil_font(font='Arial.Unicode.ttf' if is_chinese(example) else font,
80
  size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
81
  else: # use cv2
82
  self.im = im
 
19
  from PIL import Image, ImageDraw, ImageFont
20
 
21
  from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
22
+ increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
23
  from utils.metrics import fitness
24
 
25
  # Settings
 
72
  # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
73
  def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
74
  assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
75
+ non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
76
+ self.pil = pil or non_ascii
77
  if self.pil: # use PIL
78
  self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
79
  self.draw = ImageDraw.Draw(self.im)
80
+ self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
81
  size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
82
  else: # use cv2
83
  self.im = im