Commit
•
c9042dc
1
Parent(s):
d876caa
Improved non-latin `Annotator()` plotting (#7488)
Browse files* Improved non-latin labels Annotator plotting
May resolve https://github.com/ultralytics/yolov5/issues/7460
* Update train.py
* Update train.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add progress arg
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- train.py +5 -3
- utils/general.py +2 -2
- utils/plots.py +4 -3
train.py
CHANGED
@@ -48,13 +48,13 @@ from utils.datasets import create_dataloader
|
|
48 |
from utils.downloads import attempt_download
|
49 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
|
51 |
-
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
|
52 |
-
print_args, print_mutation, strip_optimizer)
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
56 |
from utils.metrics import fitness
|
57 |
-
from utils.plots import plot_evolve, plot_labels
|
58 |
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
|
59 |
|
60 |
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
@@ -105,6 +105,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
105 |
init_seeds(1 + RANK)
|
106 |
with torch_distributed_zero_first(LOCAL_RANK):
|
107 |
data_dict = data_dict or check_dataset(data) # check if None
|
|
|
|
|
108 |
train_path, val_path = data_dict['train'], data_dict['val']
|
109 |
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
110 |
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
|
|
48 |
from utils.downloads import attempt_download
|
49 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
|
51 |
+
intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods,
|
52 |
+
one_cycle, print_args, print_mutation, strip_optimizer)
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
56 |
from utils.metrics import fitness
|
57 |
+
from utils.plots import check_font, plot_evolve, plot_labels
|
58 |
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
|
59 |
|
60 |
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
|
|
105 |
init_seeds(1 + RANK)
|
106 |
with torch_distributed_zero_first(LOCAL_RANK):
|
107 |
data_dict = data_dict or check_dataset(data) # check if None
|
108 |
+
if not is_ascii(data_dict['names']): # non-latin labels, i.e. asian, arabic, cyrillic
|
109 |
+
check_font('Arial.Unicode.ttf', progress=True)
|
110 |
train_path, val_path = data_dict['train'], data_dict['val']
|
111 |
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
112 |
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
utils/general.py
CHANGED
@@ -424,13 +424,13 @@ def check_file(file, suffix=''):
|
|
424 |
return files[0] # return file
|
425 |
|
426 |
|
427 |
-
def check_font(font=FONT):
|
428 |
# Download font to CONFIG_DIR if necessary
|
429 |
font = Path(font)
|
430 |
if not font.exists() and not (CONFIG_DIR / font.name).exists():
|
431 |
url = "https://ultralytics.com/assets/" + font.name
|
432 |
LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
|
433 |
-
torch.hub.download_url_to_file(url, str(font), progress=
|
434 |
|
435 |
|
436 |
def check_dataset(data, autodownload=True):
|
|
|
424 |
return files[0] # return file
|
425 |
|
426 |
|
427 |
+
def check_font(font=FONT, progress=False):
|
428 |
# Download font to CONFIG_DIR if necessary
|
429 |
font = Path(font)
|
430 |
if not font.exists() and not (CONFIG_DIR / font.name).exists():
|
431 |
url = "https://ultralytics.com/assets/" + font.name
|
432 |
LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
|
433 |
+
torch.hub.download_url_to_file(url, str(font), progress=progress)
|
434 |
|
435 |
|
436 |
def check_dataset(data, autodownload=True):
|
utils/plots.py
CHANGED
@@ -19,7 +19,7 @@ import torch
|
|
19 |
from PIL import Image, ImageDraw, ImageFont
|
20 |
|
21 |
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
22 |
-
increment_path, is_ascii,
|
23 |
from utils.metrics import fitness
|
24 |
|
25 |
# Settings
|
@@ -72,11 +72,12 @@ class Annotator:
|
|
72 |
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
73 |
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
74 |
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
75 |
-
|
|
|
76 |
if self.pil: # use PIL
|
77 |
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
78 |
self.draw = ImageDraw.Draw(self.im)
|
79 |
-
self.font = check_pil_font(font='Arial.Unicode.ttf' if
|
80 |
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
81 |
else: # use cv2
|
82 |
self.im = im
|
|
|
19 |
from PIL import Image, ImageDraw, ImageFont
|
20 |
|
21 |
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
22 |
+
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
|
23 |
from utils.metrics import fitness
|
24 |
|
25 |
# Settings
|
|
|
72 |
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
73 |
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
74 |
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
75 |
+
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
76 |
+
self.pil = pil or non_ascii
|
77 |
if self.pil: # use PIL
|
78 |
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
79 |
self.draw = ImageDraw.Draw(self.im)
|
80 |
+
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
|
81 |
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
82 |
else: # use cv2
|
83 |
self.im = im
|