|
import threading |
|
import os |
|
import contextlib |
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image, ImageDraw, ImageFont, ExifTags |
|
from PIL import __version__ as pil_version |
|
from multiprocessing.pool import ThreadPool |
|
import numpy as np |
|
from itertools import repeat |
|
import glob |
|
import cv2 |
|
import tempfile |
|
import hashlib |
|
from pathlib import Path |
|
import time |
|
import torchvision |
|
import math |
|
import re |
|
from typing import List, Union, Dict |
|
import pkg_resources as pkg |
|
from types import SimpleNamespace |
|
from torch.utils.data import Dataset, DataLoader |
|
from tqdm import tqdm |
|
import random |
|
import yaml |
|
import logging.config |
|
import sys |
|
import pathlib |
|
# Directory that contains this module; it is appended to sys.path (presumably so sibling modules import by name).
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(os.fspath(CURRENT_DIR))
|
|
|
LOGGING_NAME = 'ultralytics'

LOGGER = logging.getLogger(LOGGING_NAME)

# Re-bind LOGGER.info / LOGGER.warning as single-argument wrappers around the original bound methods.
# The method is captured as a default argument: a bare closure over the loop variable would be late-bound,
# so after the loop both wrappers (including LOGGER.info) would call LOGGER.warning.
for fn in LOGGER.info, LOGGER.warning:
    setattr(LOGGER, fn.__name__, lambda x, _fn=fn: _fn(x))
|
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # accepted image suffixes

VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # accepted video suffixes

TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm progress-bar layout

# os.cpu_count() returns None when the CPU count cannot be determined; fall back to a single thread.
NUM_THREADS = min(8, os.cpu_count() or 1)

# Pinned memory is on unless the PIN_MEMORY environment variable is set to something other than "true".
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"

_formats = ["xyxy", "xywh", "ltwh"]  # bounding-box formats accepted by Bboxes

# Config key names grouped by expected value type.
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}

CFG_FRACTION_KEYS = {
    'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
    'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
    'mixup', 'copy_paste', 'conf'}

CFG_INT_KEYS = {
    'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
    'line_thickness', 'workspace', 'nbs'}

CFG_BOOL_KEYS = {
    'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
    'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
    'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
    'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
|
|
|
# Leave the module-level name ``orientation`` bound to the numeric EXIF tag id whose name is 'Orientation'.
for orientation in ExifTags.TAGS:
    if ExifTags.TAGS.get(orientation) == 'Orientation':
        break
|
|
|
def segments2boxes(segments):
    """
    Convert segment labels to box labels, i.e. polygon points (xy1, xy2, ...) to (xywh).

    Args:
        segments (list): Segments, each an array of points that unpacks via ``s.T`` into x and y rows
            (i.e. shape (n, 2)).

    Returns:
        (np.ndarray): Array of shape (len(segments), 4) with the xywh box enclosing each segment.
            An empty ``segments`` yields an empty (0, 4) array.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # x and y coordinates of the segment's points
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # enclosing xyxy box
    # reshape keeps the empty case 2-D; np.array([]) alone is 1-D and cannot be indexed as boxes
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))
|
|
|
|
|
def check_version(
    current: str = "0.0.0",
    minimum: str = "0.0.0",
    name: str = "version ",
    pinned: bool = False,
    hard: bool = False,
    verbose: bool = False,
) -> bool:
    """
    Check current version against the required minimum version.

    Args:
        current (str): Current version.
        minimum (str): Required minimum version.
        name (str): Name to be used in warning message.
        pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
        hard (bool): If True, raise an AssertionError if the minimum version is not met.
        verbose (bool): If True, print warning message if minimum version is not met.

    Returns:
        bool: True if minimum version is met, False otherwise.

    Raises:
        AssertionError: If ``hard`` is True and the requirement is not met.
    """
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)
    warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
    # Explicit raise (rather than ``assert``) so the documented hard failure also happens under ``python -O``.
    if hard and not result:
        raise AssertionError(warning_message)
    if verbose and not result:
        LOGGER.warning(warning_message)
    return result
|
|
|
|
|
# True when the installed torch is at least 1.9.0 (the version gate used by smart_inference_mode).
TORCH_1_9 = check_version(current=torch.__version__, minimum='1.9.0')
|
|
|
|
|
def smart_inference_mode():
    """
    Build a decorator that runs the wrapped function without autograd tracking.

    Uses ``torch.inference_mode`` when ``TORCH_1_9`` is true, otherwise ``torch.no_grad``.

    Returns:
        callable: Decorator taking a function and returning the wrapped function.
    """

    def decorate(fn):
        grad_ctx = torch.inference_mode if TORCH_1_9 else torch.no_grad
        return grad_ctx()(fn)

    return decorate
|
|
|
|
|
def box_iou(box1, box2, eps=1e-7):
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
        eps (float): Added to the union to avoid division by zero.

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # Broadcast to (N, 1, 2) vs (1, M, 2) corner pairs: top-left and bottom-right.
    tl1, br1 = box1[:, None].chunk(2, dim=2)
    tl2, br2 = box2[None].chunk(2, dim=2)
    overlap = (torch.min(br1, br2) - torch.max(tl1, tl2)).clamp(min=0).prod(dim=2)
    area1 = (br1 - tl1).prod(dim=2)
    area2 = (br2 - tl2).prod(dim=2)
    return overlap / (area1 + area2 - overlap + eps)
|
|
|
|
|
class LoadImages:
    """
    Iterate over image files and then video frames, yielding preprocessed arrays for inference.

    ``path`` may be a single file, a glob pattern (contains ``*``), a directory, a ``.txt`` file of
    whitespace-separated sources, or a list/tuple of these. Images are yielded before videos.
    """

    def __init__(
        self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1
    ):
        """
        Args:
            path (str | list | tuple): Source(s) to load (see class docstring).
            imgsz (int | tuple): Target size passed to LetterBox when ``transforms`` is None.
            stride (int): Stride passed to LetterBox.
            auto (bool): ``auto`` flag passed to LetterBox.
            transforms (callable | None): Optional replacement for the default LetterBox preprocessing.
            vid_stride (int): Number of frames advanced per video step.

        Raises:
            FileNotFoundError: If a source does not exist, or no supported images/videos are found.
        """
        if isinstance(path, str) and Path(path).suffix == ".txt":
            path = Path(path).read_text().rsplit()  # whitespace-separated list of sources
        files = []
        p = path  # fallback for the error message below when ``path`` is an empty list/tuple
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if "*" in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob pattern
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, "*.*"))))  # directory
            elif os.path.isfile(p):
                files.append(p)  # single file
            else:
                raise FileNotFoundError(f"{p} does not exist")

        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.imgsz = imgsz
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.auto = auto
        self.transforms = transforms
        self.vid_stride = vid_stride
        self.bs = 1
        if videos:
            self.orientation = None  # _new_video overwrites this when OpenCV exposes orientation metadata
            self._new_video(videos[0])
        else:
            self.cap = None
        if self.nf == 0:
            raise FileNotFoundError(
                f"No images or videos found in {p}. "
                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
            )

    def __iter__(self):
        """Reset the file counter and return self."""
        self.count = 0
        return self

    def __next__(self):
        """
        Return ``(path, im, im0, cap, s)`` for the next image or video frame.

        ``im0`` is the frame as read by OpenCV. ``im`` is ``transforms(im0)`` when transforms were given;
        otherwise it is LetterBox output transposed HWC -> CHW with the channel order reversed, made contiguous.
        ``s`` is a progress string.

        Raises:
            StopIteration: When all files (and video frames) are exhausted.
            FileNotFoundError: If an image cannot be read.
        """
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            self.mode = "video"
            # Skip ahead vid_stride frames and decode only the last one
            for _ in range(self.vid_stride):
                self.cap.grab()
            success, im0 = self.cap.retrieve()
            # Current video ended: advance to the next file (all remaining files are videos)
            while not success:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                success, im0 = self.cap.read()

            self.frame += 1
            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "

        else:
            self.count += 1
            im0 = cv2.imread(path)
            if im0 is None:
                raise FileNotFoundError(f"Image Not Found {path}")
            s = f"image {self.count}/{self.nf} {path}: "

        if self.transforms:
            im = self.transforms(im0)
        else:
            im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
            im = im.transpose((2, 0, 1))[::-1]  # HWC -> CHW, reverse channel order
            im = np.ascontiguousarray(im)

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        """Open ``path`` with OpenCV and reset the per-video frame counters."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        if hasattr(cv2, "CAP_PROP_ORIENTATION_META"):  # not available in all OpenCV builds
            self.orientation = int(
                self.cap.get(cv2.CAP_PROP_ORIENTATION_META)
            )

    def _cv2_rotate(self, im):
        """Rotate ``im`` according to ``self.orientation``; other values return ``im`` unchanged.

        NOTE(review): not called from ``__next__`` in this class.
        """
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        """Return the number of files (images + videos)."""
        return self.nf
|
|
|
|
|
class LetterBox:
    """Resize an image and pad it to a target shape for detection, instance segmentation and pose.

    Works in two modes:
    - ``LetterBox(...)(image=img)`` returns only the padded image.
    - ``LetterBox(...)(labels)`` updates and returns the label dict (``labels["img"]`` is the source image).
    """

    def __init__(
        self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
    ):
        """
        Args:
            new_shape (int | tuple): Target (height, width); an int means a square target.
            auto (bool): Pad only up to the next multiple of ``stride`` instead of to the full target.
            scaleFill (bool): Stretch to the target with no padding (ignored when ``auto`` is True).
            scaleup (bool): Allow enlarging; when False the scale ratio is capped at 1.0.
            stride (int): Stride used by ``auto`` padding.
        """
        self.new_shape = new_shape
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride

    def __call__(self, labels=None, image=None):
        """Letterbox ``image`` (or ``labels["img"]``); return the image, or the updated labels if any remain."""
        if labels is None:
            labels = {}
        img = labels.get("img") if image is None else image
        shape = img.shape[:2]  # current (height, width)
        # A per-batch rectangular target (set by the dataset when rect=True) overrides the default shape
        new_shape = labels.pop("rect_shape", self.new_shape)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale ratio (new / old): the largest uniform scale that fits inside new_shape
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

        if not self.scaleup:
            r = min(r, 1.0)

        # Resized size before padding, as (width, height) for cv2.resize, and the total padding
        ratio = r, r
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
        if self.auto:
            # Pad only up to the next multiple of stride (minimum rectangle)
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)
        elif self.scaleFill:
            # Stretch to the exact target size: no padding, separate width/height ratios
            dw, dh = 0.0, 0.0
            new_unpad = (new_shape[1], new_shape[0])
            ratio = (
                new_shape[1] / shape[1],
                new_shape[0] / shape[0],
            )

        # Split padding evenly between the two sides
        dw /= 2
        dh /= 2
        if labels.get("ratio_pad"):
            # Record the padding alongside the existing ratio information
            labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh))

        if shape[::-1] != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The +/-0.1 rounding puts the extra pixel of an odd padding on the bottom/right side
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
        )

        if len(labels):
            # _update_labels reads the ORIGINAL labels["img"] size, so it must run before img is replaced
            labels = self._update_labels(labels, ratio, dw, dh)
            labels["img"] = img
            labels["resized_shape"] = new_shape
            return labels
        else:
            return img

    def _update_labels(self, labels, ratio, padw, padh):
        """Apply the resize ratio and padding to ``labels["instances"]`` (converted to xyxy pixels first).

        Assumes ``labels["instances"]`` provides convert_bbox/denormalize/scale/add_padding (an
        Instances-like object defined elsewhere) — TODO confirm against that class.
        """
        labels["instances"].convert_bbox(format="xyxy")
        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])  # (width, height) of the source image
        labels["instances"].scale(*ratio)
        labels["instances"].add_padding(padw, padh)
        return labels
|
|
|
|
|
class Annotator:
    """
    Draw boxes and text labels on an image.

    PIL is used when ``pil=True`` or when ``example`` contains non-ASCII characters; otherwise OpenCV
    draws directly on the numpy array ``im``.
    """

    def __init__(
        self,
        im,
        line_width=None,
        font_size=None,
        font="Arial.ttf",
        pil=False,
        example="abc",
    ):
        """
        Args:
            im: Image to draw on; must be contiguous in memory.
            line_width (int | None): Box line width; when None it is derived from ``im.shape`` (minimum 2).
            font_size (int | None): Accepted for API compatibility; not used (PIL's default font is loaded).
            font (str): Accepted for API compatibility; not used.
            pil (bool): Force PIL drawing.
            example (str): Sample label text; non-ASCII text forces PIL drawing.
        """
        assert (
            im.data.contiguous
        ), "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."

        non_ascii = not is_ascii(example)
        self.pil = pil or non_ascii
        if self.pil:
            # Pillow 9.2 deprecated ImageFont.getsize (removed in Pillow 10) in favour of getbbox
            self.pil_9_2_0_check = check_version(
                pil_version, "9.2.0"
            )
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = ImageFont.load_default()
        else:
            self.im = im
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)

    def _text_size(self, text):
        """Return ``(width, height)`` of ``text`` rendered with ``self.font`` (PIL mode only)."""
        if self.pil_9_2_0_check:
            _, _, w, h = self.font.getbbox(text)
        else:
            w, h = self.font.getsize(text)
        return w, h

    def box_label(
        self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)
    ):
        """Draw ``box`` (x1, y1, x2, y2) and, if ``label`` is given, a filled label tag.

        The tag is drawn above the box when it fits there, otherwise just inside the top edge.
        """
        if isinstance(box, torch.Tensor):
            box = box.tolist()
        # NOTE(review): if self.pil is False but the label is non-ASCII, self.draw does not exist here.
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)
            if label:
                w, h = self._text_size(label)
                outside = box[1] - h >= 0  # label fits above the box
                self.draw.rectangle(
                    (
                        box[0],
                        box[1] - h if outside else box[1],
                        box[0] + w + 1,
                        box[1] + 1 if outside else box[1] + h + 1,
                    ),
                    fill=color,
                )

                self.draw.text(
                    (box[0], box[1] - h if outside else box[1]),
                    label,
                    fill=txt_color,
                    font=self.font,
                )
        else:
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(
                self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA
            )
            if label:
                tf = max(self.lw - 1, 1)  # font thickness

                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
                outside = p1[1] - h >= 3
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled tag background
                cv2.putText(
                    self.im,
                    label,
                    (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                    0,
                    self.lw / 3,
                    txt_color,
                    thickness=tf,
                    lineType=cv2.LINE_AA,
                )

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Draw a rectangle with PIL (PIL mode only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top"):
        """Draw ``text`` at ``xy`` with PIL (PIL mode only).

        With ``anchor="bottom"`` the text is shifted up by its height so ``xy`` marks its bottom edge.
        """
        if anchor == "bottom":
            _, h = self._text_size(text)
            # Build a new point: works for tuple ``xy`` and does not mutate the caller's sequence
            xy = [xy[0], xy[1] + 1 - h]
        self.draw.text(xy, text, fill=txt_color, font=self.font)

    def fromarray(self, im):
        """Replace the PIL canvas with ``im`` (converted from a numpy array if needed)."""
        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
        self.draw = ImageDraw.Draw(self.im)

    def result(self):
        """Return the annotated image as a numpy array."""
        return np.asarray(self.im)
|
|
|
|
|
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nm=0,
):
    """
    Run non-maximum suppression on raw model predictions.

    Args:
        prediction (torch.Tensor | list | tuple): Model output; for a list/tuple only element 0 is used.
            Indexed as (batch, 4 + nc + nm, num_candidates): rows 0-3 are xywh boxes, the next nc rows
            are class scores and the last nm rows are mask coefficients.
        conf_thres (float): Minimum class score for a candidate to be kept, in [0, 1].
        iou_thres (float): IoU threshold for torchvision NMS, in [0, 1].
        classes (list[int] | None): If given, keep only detections of these class indices.
        agnostic (bool): If True, suppress across classes instead of per class.
        multi_label (bool): If True (and nc > 1), one box may yield a detection per class above conf_thres.
        labels (sequence): Optional per-image tensors of a-priori labels with rows (cls, x, y, w, h),
            appended as extra candidates.
        max_det (int): Maximum number of detections kept per image.
        nm (int): Number of mask coefficients at the end of each prediction.

    Returns:
        list[torch.Tensor]: One (n, 6 + nm) tensor per image: (x1, y1, x2, y2, conf, cls, *mask_coefs).
    """

    assert (
        0 <= conf_thres <= 1
    ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert (
        0 <= iou_thres <= 1
    ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"

    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0]
    device = prediction.device
    mps = "mps" in device.type
    if mps:
        prediction = prediction.cpu()  # processed on CPU; results are moved back to ``device`` below
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # index of the first mask coefficient
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates: best class score above threshold

    max_wh = 7680  # per-class box offset so class-aware NMS runs as a single torchvision call
    max_nms = 30000  # maximum number of boxes passed to torchvision NMS
    time_limit = 0.5 + 0.05 * bs  # seconds before giving up on the remaining images
    redundant = True  # with merge, keep only boxes supported by more than one overlap
    multi_label &= nc > 1
    merge = False  # weighted box merging (disabled)

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):
        x = x.transpose(0, -1)[xc[xi]]  # (candidates, 4 + nc + nm), confidence-filtered

        # Append a-priori labels as candidates with the same 4 + nc + nm column layout as predictions
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # one-hot class score
            x = torch.cat((x, v), 0)

        if not x.shape[0]:
            continue

        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        n = x.shape[0]
        if not n:
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # keep the most confident max_nms boxes

        # Offset boxes by class so boxes of different classes never overlap (unless agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        i = i[:max_det]
        if merge and (1 < n < 3e3):
            # Replace each kept box by the score-weighted mean of the boxes it overlaps
            iou = box_iou(boxes[i], boxes) > iou_thres
            weights = iou * scores[None]
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(
                1, keepdim=True
            )
            if redundant:
                i = i[iou.sum(1) > 1]

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
            break

    return output
|
|
|
|
|
class Colors:
    """Fixed 20-colour palette; ``Colors()(i)`` returns colour ``i`` modulo the palette size."""

    def __init__(self):
        """Build ``self.palette`` (RGB tuples) and ``self.n`` (palette length)."""
        hex_codes = (
            "FF3838 FF9D97 FF701F FFB21D CFD231 48F90A 92CC17 3DDB86 1A9334 00D4BB "
            "2C99A8 00C2FF 344593 6473FF 0018EC 8438FF 520085 CB38FF FF95C8 FF37C7"
        ).split()
        self.palette = [self.hex2rgb("#" + code) for code in hex_codes]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return palette entry ``int(i) % n`` as RGB, or as BGR when ``bgr`` is True."""
        red, green, blue = self.palette[int(i) % self.n]
        if bgr:
            return blue, green, red
        return red, green, blue

    @staticmethod
    def hex2rgb(h):
        """Convert a '#RRGGBB' string to an (r, g, b) tuple of ints."""
        digits = h[1:]
        return tuple(int(digits[k:k + 2], 16) for k in range(0, 6, 2))
|
|
|
|
|
colors = Colors()  # shared module-level palette instance (used by plot_images to colour boxes by class)
|
|
|
|
|
def threaded(func):
    """
    Decorator: run ``func`` in a new daemon thread on each call.

    The decorated function returns the started ``threading.Thread`` (``func``'s own return value is
    discarded); call ``join()`` on it to wait for completion.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(func)  # preserve __name__, __doc__, etc. of the wrapped function
    def wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread

    return wrapper
|
|
|
|
|
def plot_images(
    images,
    batch_idx,
    cls,
    bboxes,
    masks=np.zeros(0, dtype=np.uint8),
    paths=None,
    fname="images.jpg",
    names=None,
):
    """
    Tile up to 16 images into a square mosaic, draw their boxes and labels, and save it to ``fname``.

    Args:
        images (torch.Tensor | np.ndarray): Batch indexed as (bs, channels, h, w); values in [0, 1] are
            rescaled to [0, 255] (without modifying the caller's data).
        batch_idx: Image index of each box.
        cls: Class index of each box.
        bboxes: Boxes in xywh (treated as normalized when all values are <= 1.01), with an optional
            5th column of confidences; boxes with confidence <= 0.25 are not drawn.
        masks: Converted to numpy when a tensor; not otherwise used here.
        paths (list[str] | None): Image paths; the file name is drawn on each tile.
        fname (str): Output image path.
        names (dict | list | None): Class names indexed by class id.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()

    max_size = 1920  # maximum mosaic side in pixels
    max_subplots = 16  # maximum number of tiles (4x4)
    bs, _, h, w = images.shape
    bs = min(bs, max_subplots)  # number of tiles actually drawn
    ns = np.ceil(bs**0.5)  # tiles per side
    if np.max(images[0]) <= 1:
        # Out-of-place: a CPU float tensor's .numpy() shares memory, so ``*=`` would modify the caller's data
        images = images * 255

    # Build the mosaic; tiles are placed column by column
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
    for i, im in enumerate(images):
        if i == max_subplots:
            break
        x, y = int(w * (i // ns)), int(h * (i % ns))  # tile origin
        im = im.transpose(1, 2, 0)
        mosaic[y : y + h, x : x + w, :] = im

    # Downscale so the mosaic side does not exceed max_size
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate each tile
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(
        mosaic, line_width=2, font_size=fs, pil=True, example=names
    )
    for i in range(bs):  # only the tiles placed above (at most max_subplots)
        x, y = int(w * (i // ns)), int(h * (i % ns))  # tile origin
        annotator.rectangle(
            [x, y, x + w, y + h], None, (255, 255, 255), width=2
        )  # tile border
        if paths:
            annotator.text(
                (x + 5, y + 5 + h),
                text=Path(paths[i]).name[:40],
                txt_color=(220, 220, 220),
            )  # file name
        if len(cls) > 0:
            idx = batch_idx == i

            boxes = xywh2xyxy(bboxes[idx, :4]).T  # (4, n) xyxy
            classes = cls[idx].astype("int")
            labels = bboxes.shape[1] == 4  # no confidence column -> ground-truth labels

            conf = None if labels else bboxes[idx, 4]

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # normalized coordinates -> tile pixels
                    boxes[[0, 2]] *= w
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coordinates follow the mosaic downscale
                    boxes *= scale
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                c = classes[j]
                color = colors(c)
                c = names[c] if names else c
                if labels or conf[j] > 0.25:  # 0.25 confidence threshold for predictions
                    label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                    annotator.box_label(box, label, color=color)
    annotator.im.save(fname)
|
|
|
|
|
def output_to_target(output, max_det=300):
    """
    Flatten per-image detections into plotting arrays.

    Args:
        output (list[torch.Tensor]): Per-image detections with columns (x1, y1, x2, y2, conf, cls, ...).
        max_det (int): Maximum detections taken per image.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: Image index per row, class per row, and
            (x, y, w, h, conf) per row.
    """
    rows = []
    for img_idx, det in enumerate(output):
        det = det[:max_det, :6].cpu()
        xyxy, score, label = det[:, :4], det[:, 4:5], det[:, 5:6]
        img_col = torch.full((score.shape[0], 1), img_idx)
        rows.append(torch.cat((img_col, label, xyxy2xywh(xyxy), score), 1))
    table = torch.cat(rows, 0).numpy()
    return table[:, 0], table[:, 1], table[:, 2:]
|
|
|
|
|
def is_ascii(s=""):
    """
    Return True if ``str(s)`` contains only ASCII characters.

    UTF-8 encodes every non-ASCII character in two or more bytes, so the encoded length equals the
    character count exactly when the text is pure ASCII.
    """
    text = str(s)
    return len(text.encode()) == len(text)
|
|
|
|
|
def xyxy2xywh(x):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.

    (x, y) is the box centre. Works on any array whose last axis holds the 4 coordinates; the input is
    not modified.

    Args:
        x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
    Returns:
        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = (left + right) / 2
    out[..., 1] = (top + bottom) / 2
    out[..., 2] = right - left
    out[..., 3] = bottom - top
    return out
|
|
|
|
|
def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format.

    (x, y) is the box centre. Like ``xyxy2xywh``, this works on any array whose last axis holds the
    4 coordinates (a single box, an (n, 4) array, or batched arrays); the input is not modified.

    Args:
        x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
    Returns:
        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top-left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top-left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom-right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom-right y
    return y
|
|
|
|
|
def check_det_dataset(dataset, autodownload=True):
    """
    Load a detection dataset description and resolve its paths.

    Args:
        dataset (str | Path | dict): Path to a dataset YAML file, or an already-loaded dict (modified in place).
        autodownload (bool): If True and the 'val' paths are missing, run the YAML's 'download' entry.

    Returns:
        dict: The dataset dict with 'names' as {index: name}, 'nc' equal to the number of names, 'path' as
            an absolute Path, and 'train'/'val'/'test' resolved to absolute path strings.

    Raises:
        FileNotFoundError: If the 'val' paths are missing and no download can be attempted.
    """
    data = dataset

    extract_dir = ''  # no archive extraction in this implementation; kept for the path fallback below

    if isinstance(data, (str, Path)):
        data = yaml_load(data, append_filename=True)

    # Older YAMLs list names as a sequence; normalize to {index: name}
    if isinstance(data['names'], (list, tuple)):
        data['names'] = dict(enumerate(data['names']))
    data['nc'] = len(data['names'])  # keep 'nc' in sync for both list and dict 'names'

    # Dataset root: explicit 'path', else the directory containing the YAML file
    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)

    DATASETS_DIR = os.path.abspath('.')
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()
    data['path'] = path
    for k in 'train', 'val', 'test':
        if data.get(k):
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]
        if not all(x.exists() for x in val):
            msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
            if s and autodownload:
                LOGGER.warning(msg)
            else:
                raise FileNotFoundError(msg)
            t = time.time()
            # SECURITY: 'download' comes from the dataset YAML and is executed verbatim (shell command or
            # Python code). Only use dataset files from trusted sources.
            if s.startswith('bash '):
                LOGGER.info(f'Running {s} ...')
                r = os.system(s)
            else:
                r = exec(s, {'yaml': data})
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")

    return data
|
|
|
|
|
def yaml_load(file='data.yaml', append_filename=False):
    """
    Load YAML data from a file.

    Args:
        file (str, optional): File name. Default is 'data.yaml'.
        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.

    Returns:
        dict: YAML data and file name. An empty file yields an empty dict.
    """
    with open(file, errors='ignore', encoding='utf-8') as f:
        s = f.read()
        # Strip characters outside the printable ranges before parsing
        if not s.isprintable():
            s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
        # yaml.safe_load returns None for an empty document; use {} so the filename merge below works
        data = yaml.safe_load(s) or {}
        return {**data, 'yaml_file': str(file)} if append_filename else data
|
|
|
|
|
class IterableSimpleNamespace(SimpleNamespace):
    """
    SimpleNamespace that iterates as (name, value) pairs, so it works with dict() and in for loops.
    """

    def __iter__(self):
        """Iterate over (attribute name, value) pairs."""
        return iter(self.__dict__.items())

    def __str__(self):
        """Return one 'name=value' line per attribute."""
        lines = [f"{key}={value}" for key, value in self.__dict__.items()]
        return '\n'.join(lines)

    def get(self, key, default=None):
        """Return attribute ``key``, or ``default`` if it is not set."""
        return getattr(self, key, default)
|
|
|
|
|
def colorstr(*input):
    """
    Wrap a string in ANSI escape codes.

    ``colorstr('red', 'bold', text)`` applies each named style to ``text``; a single argument defaults
    to blue + bold. Unknown style names raise KeyError.
    """
    if len(input) > 1:
        *args, string = input
    else:
        args, string = ("blue", "bold"), input[0]
    codes = {
        "black": "\033[30m",
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",
        "bold": "\033[1m",
        "underline": "\033[4m"}
    prefix = "".join(codes[style] for style in args)
    return f"{prefix}{string}{codes['end']}"
|
|
|
|
|
def seed_worker(worker_id):
    """
    DataLoader ``worker_init_fn``: seed NumPy and ``random`` from the worker's torch seed.

    ``worker_id`` is required by the callback signature but not used.
    """
    seed = torch.initial_seed() % (1 << 32)  # NumPy accepts seeds in [0, 2**32)
    random.seed(seed)
    np.random.seed(seed)
|
|
|
|
|
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
    """
    Build a YOLODataset and a DataLoader over it.

    Args:
        cfg: Config object exposing rect, imgsz, cache, single_cls, task and workers.
        batch (int): Requested batch size (capped at the dataset length).
        img_path (str | list): Image source passed to YOLODataset.
        stride (int): Model stride.
        rect (bool): Force rectangular batches (also enabled by ``cfg.rect``).
        names: Class names passed to YOLODataset.
        rank (int): Distributed rank; -1 means no distributed sampler.
        mode (str): "train" (augment, shuffle) or "val".

    Returns:
        tuple[DataLoader, YOLODataset]: The loader and its dataset.
    """
    assert mode in ["train", "val"]
    shuffle = mode == "train"
    if cfg.rect and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False
    dataset = YOLODataset(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=cfg,
        rect=cfg.rect or rect,
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == "train" else 0.5,
        prefix=colorstr(f"{mode}: "),
        use_segments=cfg.task == "segment",
        use_keypoints=cfg.task == "keypoint",
        names=names)

    batch = min(batch, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    workers = cfg.workers if mode == "train" else cfg.workers * 2
    # os.cpu_count() may return None; treat that as a single CPU
    nw = min([(os.cpu_count() or 1) // max(nd, 1), batch if batch > 1 else 0, workers])

    # Previously both names were only bound on some paths (NameError otherwise). A plain DataLoader is used
    # in all cases; a DistributedSampler is created for distributed ranks.
    if rank == -1:
        sampler = None
    else:
        from torch.utils.data.distributed import DistributedSampler
        sampler = DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205)
    return DataLoader(dataset=dataset,
                      batch_size=batch,
                      shuffle=shuffle and sampler is None,
                      num_workers=nw,
                      sampler=sampler,
                      pin_memory=PIN_MEMORY,
                      collate_fn=getattr(dataset, "collate_fn", None),
                      worker_init_fn=seed_worker,
                      generator=generator), dataset
|
|
|
|
|
class BaseDataset(Dataset):
    """Base image dataset; subclasses supply labels and transforms.

    Subclasses must implement ``get_labels`` and ``build_transforms``. Each item is
    ``self.transforms(self.get_label_info(index))``.

    Args:
        img_path (str | list): Image directory, a text file listing image paths, or a list of these.
        imgsz (int): Target length of the longest image side after loading.
        cache (bool | str): Falsy disables caching; "disk" saves decoded images as .npy files;
            any other truthy value keeps resized images in RAM.
        augment (bool): Training flag; also selects INTER_LINEAR instead of INTER_AREA when downscaling.
        hyp: Hyperparameters forwarded to ``build_transforms``.
        prefix (str): Prefix for progress and error messages.
        rect (bool): Build rectangular batch shapes (requires ``batch_size``).
        batch_size (int | None): Batch size used to group images by aspect ratio when ``rect``.
        stride (int): Rectangular batch shapes are rounded up to a multiple of this.
        pad (float): Extra padding, in units of stride, added to rectangular batch shapes.
        single_cls (bool): Stored on the instance; not used by this base class.
    """

    def __init__(
        self,
        img_path,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=None,
        prefix="",
        rect=False,
        batch_size=None,
        stride=32,
        pad=0.5,
        single_cls=False,
    ):
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
        self.ni = len(self.labels)  # number of images

        # Rectangular training: reorders im_files/labels by aspect ratio, so it must run before npy_files
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()

        # Image cache (filled by cache_images, otherwise images are loaded on demand)
        self.ims = [None] * self.ni
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        if cache:
            self.cache_images(cache)

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Return the sorted list of image files found under ``img_path``.

        Directories are searched recursively; a text file is read as one path per line, with a leading
        "./" resolved against the text file's directory. Any failure is re-raised as FileNotFoundError.
        """
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)
                if p.is_dir():
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)

                elif p.is_file():
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]

                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)

            assert im_files, f"{self.prefix}No images found"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n") from e
        return im_files

    def load_image(self, i):
        """Return ``(image, (h0, w0), (h, w))`` for image ``i``: the (resized) image, original and resized sizes.

        Uses the RAM cache when populated; otherwise reads the .npy cache file if present, else the image
        file, and resizes it so its longest side equals ``imgsz``.
        """
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # cached on disk
                im = np.load(fn)
            else:
                im = cv2.imread(f)
                if im is None:
                    raise FileNotFoundError(f"Image Not Found {f}")
            h0, w0 = im.shape[:2]  # original height, width
            r = self.imgsz / max(h0, w0)  # ratio that makes the longest side equal imgsz
            if r != 1:
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]
        return self.ims[i], self.im_hw0[i], self.im_hw[i]

    def cache_images(self, cache):
        """Cache all images in parallel: to .npy files when ``cache == "disk"``, otherwise into RAM."""
        gb = 0  # bytes cached so far (reported in GB)
        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
        fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT)
            for i, x in pbar:
                if cache == "disk":
                    gb += self.npy_files[i].stat().st_size
                else:
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x
                    gb += self.ims[i].nbytes
                pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
            pbar.close()

    def cache_images_to_disk(self, i):
        """Save image ``i`` as a .npy file next to the image, unless that file already exists."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def set_rectangle(self):
        """Sort images by aspect ratio and compute one padded (h, w) shape per batch.

        Pops "shape" (height, width) from every label, reorders ``im_files``/``labels``, and sets
        ``self.batch_shapes`` and ``self.batch`` (batch index of each image).
        """
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop("shape") for x in self.labels])  # (h, w) per image
        ar = s[:, 0] / s[:, 1]  # aspect ratio h / w
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Relative (h, w) shape per batch: fit the most extreme aspect ratio in the batch
        shapes = [[1, 1]] * nb  # entries are replaced, never mutated, so the shared inner list is safe
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        # Scale to imgsz, add padding (in stride units) and round up to a multiple of stride
        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image

    def __getitem__(self, index):
        """Return the transformed label dict for ``index``."""
        return self.transforms(self.get_label_info(index))

    def get_label_info(self, index):
        """Return a copy of label ``index`` with the loaded image, its sizes and ratio info added."""
        label = self.labels[index].copy()
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        # (height ratio, width ratio) of the resized image relative to the original
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
        label = self.update_labels_info(label)
        return label

    def __len__(self):
        """Return the number of labels (images)."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Hook for subclasses to customise the label format; returns ``label`` unchanged here."""
        return label

    def build_transforms(self, hyp=None):
        """Return the transform pipeline; subclasses must implement this.

        Example:
            if self.augment:
                # training transforms
                return Compose([])
            else:
                # val transforms
                return Compose([])
        """
        raise NotImplementedError

    def get_labels(self):
        """Return the list of label dicts; subclasses must implement this.

        Each element should look like:
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
        """
        raise NotImplementedError
|
|
|
|
|
def img2label_paths(img_paths):
    """
    Map image paths to label paths.

    The last ``/images/`` directory component is replaced by ``/labels/`` and the file extension by
    ``.txt`` (using ``os.sep`` as separator).
    """
    img_dir, label_dir = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
    label_paths = []
    for img_path in img_paths:
        swapped = label_dir.join(img_path.rsplit(img_dir, 1))
        stem = swapped.rsplit(".", 1)[0]
        label_paths.append(stem + ".txt")
    return label_paths
|
|
|
|
|
def get_hash(paths):
    """
    Return an MD5 hex digest of the total size of the existing ``paths`` plus their concatenated names.

    Missing paths contribute their name but no size (presumably used as a cheap cache key).
    """
    total_size = 0
    for p in paths:
        if os.path.exists(p):
            total_size += os.path.getsize(p)
    digest = hashlib.md5(str(total_size).encode())
    digest.update("".join(paths).encode())
    return digest.hexdigest()
|
|
|
|
|
class Compose:
    """Chain callables: each transform receives the previous transform's output."""

    def __init__(self, transforms):
        """Store the list of transforms (kept by reference)."""
        self.transforms = transforms

    def __call__(self, data):
        """Apply all transforms in order and return the final result."""
        result = data
        for transform in self.transforms:
            result = transform(result)
        return result

    def append(self, transform):
        """Add ``transform`` to the end of the pipeline."""
        self.transforms.append(transform)

    def tolist(self):
        """Return the underlying transform list (not a copy)."""
        return self.transforms

    def __repr__(self):
        """Return the class name followed by one transform per line."""
        body = "".join(f"\n {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
|
|
|
|
|
class Format:
    """Convert a label dict into tensors: CHW image, cls, bboxes and optionally keypoints and batch_idx.

    ``return_mask``, ``mask_ratio`` and ``mask_overlap`` are stored but not used by this implementation.
    """

    def __init__(self,
                 bbox_format="xywh",
                 normalize=True,
                 return_mask=False,
                 return_keypoint=False,
                 mask_ratio=4,
                 mask_overlap=True,
                 batch_idx=True):
        self.bbox_format = bbox_format
        self.normalize = normalize
        self.return_mask = return_mask
        self.return_keypoint = return_keypoint
        self.mask_ratio = mask_ratio
        self.mask_overlap = mask_overlap
        self.batch_idx = batch_idx

    def __call__(self, labels):
        """Replace "img", "cls" and "instances" in ``labels`` with tensors and return the dict."""
        img = labels.pop("img")
        height, width = img.shape[:2]
        cls = labels.pop("cls")
        instances = labels.pop("instances")
        # Boxes go to the requested format in pixels, then are re-normalized if requested
        instances.convert_bbox(format=self.bbox_format)
        instances.denormalize(width, height)
        count = len(instances)

        if self.normalize:
            instances.normalize(width, height)
        labels["img"] = self._format_img(img)
        if count:
            labels["cls"] = torch.from_numpy(cls)
            labels["bboxes"] = torch.from_numpy(instances.bboxes)
        else:
            labels["cls"] = torch.zeros(count)
            labels["bboxes"] = torch.zeros((count, 4))
        if self.return_keypoint:
            if count:
                labels["keypoints"] = torch.from_numpy(instances.keypoints)
            else:
                labels["keypoints"] = torch.zeros((count, 17, 2))
        if self.batch_idx:
            labels["batch_idx"] = torch.zeros(count)
        return labels

    def _format_img(self, img):
        """HWC (or HW) numpy image -> contiguous CHW tensor with the channel order reversed."""
        if img.ndim < 3:
            img = img[..., None]
        chw = img.transpose(2, 0, 1)[::-1]
        return torch.from_numpy(np.ascontiguousarray(chw))
|
|
|
class Bboxes:
    """Numpy container for N boxes stored in one of the layouts listed in ``_formats``.

    Layouts: ``"xyxy"`` (x1, y1, x2, y2), ``"xywh"`` (center x, center y,
    width, height) and ``"ltwh"`` (left, top, width, height).
    Only numpy arrays are supported.
    """

    def __init__(self, bboxes, format="xyxy") -> None:
        assert format in _formats
        # Promote a single (4,) box to shape (1, 4).
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
        assert bboxes.ndim == 2
        assert bboxes.shape[1] == 4
        self.bboxes = bboxes
        self.format = format

    @staticmethod
    def _ltwh2xyxy(x):
        """Return a copy of (N, 4) ltwh boxes converted to xyxy."""
        y = np.copy(x)
        y[:, 2] = x[:, 0] + x[:, 2]
        y[:, 3] = x[:, 1] + x[:, 3]
        return y

    @staticmethod
    def _xyxy2ltwh(x):
        """Return a copy of (N, 4) xyxy boxes converted to ltwh."""
        y = np.copy(x)
        y[:, 2] = x[:, 2] - x[:, 0]
        y[:, 3] = x[:, 3] - x[:, 1]
        return y

    def convert(self, format):
        """Convert the boxes in place to ``format``; does nothing if already in it."""
        assert format in _formats
        if self.format == format:
            return
        # Route every conversion through xyxy, so each layout needs only one
        # rule in each direction (previously any "ltwh" conversion crashed).
        if self.format == "xywh":
            xyxy = xywh2xyxy(self.bboxes)
        elif self.format == "ltwh":
            xyxy = self._ltwh2xyxy(self.bboxes)
        else:
            xyxy = self.bboxes
        if format == "xywh":
            bboxes = xyxy2xywh(xyxy)
        elif format == "ltwh":
            bboxes = self._xyxy2ltwh(xyxy)
        else:
            bboxes = xyxy
        self.bboxes = bboxes
        self.format = format

    def areas(self):
        """Return the (N,) array of box areas without changing ``self.format``."""
        b = self.bboxes
        if self.format == "xyxy":
            return (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
        # "xywh" and "ltwh" both store width and height in columns 2 and 3.
        return b[:, 2] * b[:, 3]

    def mul(self, scale):
        """Multiply each coordinate column in place.

        Args:
            scale (tuple | List | int | float): per-column factors for the four
                coords, or a single number applied to all four.
        """
        if isinstance(scale, (int, float, np.number)):
            scale = (scale,) * 4
        assert isinstance(scale, (tuple, list))
        assert len(scale) == 4
        for col, factor in enumerate(scale):
            self.bboxes[:, col] *= factor

    def add(self, offset):
        """Add an offset to each coordinate column in place.

        Args:
            offset (tuple | List | int | float): per-column offsets for the four
                coords, or a single number applied to all four.
        """
        if isinstance(offset, (int, float, np.number)):
            offset = (offset,) * 4
        assert isinstance(offset, (tuple, list))
        assert len(offset) == 4
        for col, delta in enumerate(offset):
            self.bboxes[:, col] += delta

    def __len__(self):
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
        """
        Concatenate a list of Bboxes into a single Bboxes.

        Arguments:
            boxes_list (list[Bboxes]): boxes to join; presumably all share one
                format, and the result takes the first element's format.
            axis (int): concatenation axis passed to ``np.concatenate``.

        Returns:
            Bboxes: the concatenated boxes (an empty (0, 4) Bboxes for an empty list).
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            return cls(np.empty((0, 4)))
        assert all(isinstance(box, Bboxes) for box in boxes_list)

        if len(boxes_list) == 1:
            return boxes_list[0]
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis), format=boxes_list[0].format)

    def __getitem__(self, index) -> "Bboxes":
        """
        Args:
            index: int, slice, or a BoolArray

        Returns:
            Bboxes: Create a new :class:`Bboxes` by indexing, keeping this object's format.
        """
        if isinstance(index, (int, np.integer)):
            # numpy's ndarray.view() takes a dtype, not a shape; reshape keeps the box 2-D.
            return type(self)(self.bboxes[index].reshape(1, -1), format=self.format)
        b = self.bboxes[index]
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
        return type(self)(b, format=self.format)
|
|
|
|
|
def resample_segments(segments, n=1000):
    """
    Resample every polygon in ``segments`` to exactly ``n`` points.

    Each polygon is first closed by repeating its first vertex. The closed
    outline is then linearly interpolated at ``n`` evenly spaced positions
    along its vertex index. The input container is modified in place and
    also returned.

    Args:
        segments (list): a list of (k, 2) arrays of xy points.
        n (int): number of points per resampled segment. Defaults to 1000.

    Returns:
        segments (list): the same container, with each entry replaced by an (n, 2) array.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)
        sample_at = np.linspace(0, len(closed) - 1, n)
        vertex_pos = np.arange(len(closed))
        coords = [np.interp(sample_at, vertex_pos, closed[:, axis]) for axis in range(2)]
        segments[idx] = np.concatenate(coords).reshape(2, -1).T
    return segments
|
|
|
|
|
class Instances:
    """Annotations for the objects in one image: boxes plus optional segments and keypoints.

    Boxes are held in a ``Bboxes`` object (exposed as the ``bboxes`` array).
    ``segments`` is an array of shape (num_segments, 1000, 2): every polygon
    is resampled to 1000 points, and the array is an empty (0, 1000, 2)
    float32 array when there are no polygons. ``keypoints`` is ``None`` or an
    array whose last axis holds (x, y). ``normalized`` records whether
    coordinates are relative to the image width/height.
    """

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
        """
        Args:
            bboxes (ndarray): bboxes with shape [N, 4].
            segments (list | ndarray): polygon segments, each an (n, 2) array of
                xy points; resampled here to 1000 points each.
            keypoints (ndarray): keypoints with shape [N, 17, 2].
            bbox_format (str): layout of ``bboxes``.
            normalized (bool): True if coordinates are relative to image size.
        """
        if segments is None:
            segments = []
        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
        self.keypoints = keypoints
        self.normalized = normalized

        if len(segments) > 0:
            # Equal point counts let all polygons stack into a single array.
            segments = resample_segments(segments)
            segments = np.stack(segments, axis=0)
        else:
            segments = np.zeros((0, 1000, 2), dtype=np.float32)
        self.segments = segments

    @classmethod
    def _from_arrays(cls, bboxes, segments, keypoints, bbox_format, normalized):
        """Build an instance from already-resampled segments, skipping a second resampling."""
        inst = cls(bboxes, None, keypoints, bbox_format=bbox_format, normalized=normalized)
        inst.segments = segments
        return inst

    def convert_bbox(self, format):
        """Convert the boxes in place to ``format``."""
        self._bboxes.convert(format=format)

    def bbox_areas(self):
        """Return the per-box areas computed by ``Bboxes.areas``."""
        return self._bboxes.areas()

    def scale(self, scale_w, scale_h, bbox_only=False):
        """Multiply x coords by ``scale_w`` and y coords by ``scale_h`` in place.

        Unlike ``denormalize``, this ignores and does not change ``self.normalized``.
        If ``bbox_only`` is True, segments and keypoints are left untouched.
        """
        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
        if bbox_only:
            return
        self.segments[..., 0] *= scale_w
        self.segments[..., 1] *= scale_h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= scale_w
            self.keypoints[..., 1] *= scale_h

    def denormalize(self, w, h):
        """Convert normalized coordinates to pixels (no-op if already absolute)."""
        if not self.normalized:
            return
        self._bboxes.mul(scale=(w, h, w, h))
        self.segments[..., 0] *= w
        self.segments[..., 1] *= h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= w
            self.keypoints[..., 1] *= h
        self.normalized = False

    def normalize(self, w, h):
        """Convert pixel coordinates to fractions of (w, h) (no-op if already normalized)."""
        if self.normalized:
            return
        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
        self.segments[..., 0] /= w
        self.segments[..., 1] /= h
        if self.keypoints is not None:
            self.keypoints[..., 0] /= w
            self.keypoints[..., 1] /= h
        self.normalized = True

    def add_padding(self, padw, padh):
        """Shift all coordinates by (padw, padh); requires absolute coordinates."""
        assert not self.normalized, "you should add padding with absolute coordinates."
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
        self.segments[..., 1] += padh
        if self.keypoints is not None:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> "Instances":
        """
        Args:
            index: int, slice, or a BoolArray

        Returns:
            Instances: Create a new :class:`Instances` by indexing. An int index
            keeps the leading instance axis, so boxes, segments and keypoints
            stay batched.
        """
        if isinstance(index, (int, np.integer)) and not isinstance(index, bool):
            index = [index]
        segments = self.segments[index] if len(self.segments) else self.segments
        keypoints = self.keypoints[index] if self.keypoints is not None else None
        bboxes = self.bboxes[index]
        # Segments are already resampled; bypass __init__'s resampling.
        return self._from_arrays(bboxes, segments, keypoints, self._bboxes.format, self.normalized)

    def flipud(self, h):
        """Mirror all coordinates vertically within an image of height ``h``."""
        fmt = self._bboxes.format
        if fmt == "xyxy":
            y1 = self.bboxes[:, 1].copy()
            y2 = self.bboxes[:, 3].copy()
            self.bboxes[:, 1] = h - y2
            self.bboxes[:, 3] = h - y1
        elif fmt == "ltwh":
            # The new top edge is the mirrored old bottom edge (top + height).
            self.bboxes[:, 1] = h - self.bboxes[:, 1] - self.bboxes[:, 3]
        else:
            self.bboxes[:, 1] = h - self.bboxes[:, 1]
        self.segments[..., 1] = h - self.segments[..., 1]
        if self.keypoints is not None:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        """Mirror all coordinates horizontally within an image of width ``w``."""
        fmt = self._bboxes.format
        if fmt == "xyxy":
            x1 = self.bboxes[:, 0].copy()
            x2 = self.bboxes[:, 2].copy()
            self.bboxes[:, 0] = w - x2
            self.bboxes[:, 2] = w - x1
        elif fmt == "ltwh":
            # The new left edge is the mirrored old right edge (left + width).
            self.bboxes[:, 0] = w - self.bboxes[:, 0] - self.bboxes[:, 2]
        else:
            self.bboxes[:, 0] = w - self.bboxes[:, 0]
        self.segments[..., 0] = w - self.segments[..., 0]
        if self.keypoints is not None:
            self.keypoints[..., 0] = w - self.keypoints[..., 0]

    def clip(self, w, h):
        """Clip all coordinates to [0, w] x [0, h], keeping the current box format."""
        ori_format = self._bboxes.format
        self.convert_bbox(format="xyxy")
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if ori_format != "xyxy":
            self.convert_bbox(format=ori_format)
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
        if self.keypoints is not None:
            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)

    def update(self, bboxes, segments=None, keypoints=None):
        """Replace the boxes (same format), and the segments/keypoints when given."""
        new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
        self._bboxes = new_bboxes
        if segments is not None:
            self.segments = segments
        if keypoints is not None:
            self.keypoints = keypoints

    def __len__(self):
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
        """
        Concatenate a list of Instances into a single Instances.

        Arguments:
            instances_list (list[Instances]): instances to join. The box format,
                normalization flag and keypoint presence are taken from the first element.
            axis (int): concatenation axis passed to ``np.concatenate``.

        Returns:
            Instances: the concatenated instances (an empty Instances for an empty list).
        """
        assert isinstance(instances_list, (list, tuple))
        if not instances_list:
            return cls(np.empty((0, 4)))
        assert all(isinstance(instance, Instances) for instance in instances_list)

        if len(instances_list) == 1:
            return instances_list[0]

        use_keypoint = instances_list[0].keypoints is not None
        bbox_format = instances_list[0]._bboxes.format
        normalized = instances_list[0].normalized

        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
        # Segments are already resampled; bypass __init__'s resampling.
        return cls._from_arrays(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

    @property
    def bboxes(self):
        """The underlying (N, 4) box array (mutations write through to the Bboxes object)."""
        return self._bboxes.bboxes
|
|
|
|
|
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
    """
    Report whether a temporary file can be created inside a directory.

    Args:
        dir_path (str) or (Path): The directory to probe.

    Returns:
        bool: False if creating (or closing) a temporary file there raises
        ``OSError``, True otherwise.
    """
    try:
        with tempfile.TemporaryFile(dir=dir_path):
            pass
    except OSError:
        writeable = False
    else:
        writeable = True
    return writeable
|
|
|
|
|
class YOLODataset(BaseDataset):
    """YOLO Dataset: images with box, polygon-segment or keypoint labels.

    Args:
        img_path (str): image path, forwarded to ``BaseDataset``.
        imgsz (int): image size, forwarded to ``BaseDataset``; ``build_transforms``
            letterboxes to ``(self.imgsz, self.imgsz)``.
        cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls:
            forwarded positionally to ``BaseDataset.__init__``. ``prefix`` is
            also prepended to this class's log and progress messages.
        use_segments (bool): keep polygon segments. Mutually exclusive with ``use_keypoints``.
        use_keypoints (bool): parse keypoint labels. Mutually exclusive with ``use_segments``.
        names (dict | list): class names; ``len(names)`` is the class count that labels are checked against.
    """

    # A cache file whose "version" differs from this is rebuilt by get_labels.
    cache_version = '1.0.1'
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 img_path,
                 imgsz=640,
                 cache=False,
                 augment=True,
                 hyp=None,
                 prefix="",
                 rect=False,
                 batch_size=None,
                 stride=32,
                 pad=0.0,
                 single_cls=False,
                 use_segments=False,
                 use_keypoints=False,
                 names=None):
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.names = names
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)

    def cache_labels(self, path=Path("./labels.cache")):
        """Verify every image/label pair in parallel and save the results to ``path``.

        Any existing file at ``path`` is deleted first. Images that fail
        verification are dropped, and ``self.im_files`` is updated to match.

        Args:
            path (Path): cache file location.

        Returns:
            (dict): keys ``"labels"`` (list of label dicts), ``"hash"``,
            ``"results"`` ``(nf, nm, ne, nc, total)``, ``"msgs"`` and ``"version"``.
        """
        if path.exists():
            path.unlink()
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # missing, found, empty, corrupt counts; warnings
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
                                             repeat(self.use_keypoints), repeat(len(self.names))))
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # None for corrupt pairs
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh"))
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs
        x["version"] = self.cache_version
        self.im_files = [lb["im_file"] for lb in x["labels"]]  # drop corrupt images
        if is_dir_writeable(path.parent):
            np.save(str(path), x)
            path.with_suffix(".cache.npy").rename(path)  # np.save appends .npy; restore the intended name
            LOGGER.info(f"{self.prefix}New cache created: {path}")
        else:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable")
        return x

    def get_labels(self):
        """Return one label dict per valid image, loading the label cache or rebuilding it.

        Also sets ``self.label_files`` and keeps ``self.im_files`` index-aligned
        with the returned labels.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            # NOTE(review): allow_pickle=True unpickles the cache file; only load caches from trusted datasets.
            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True
            assert cache["version"] == self.cache_version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)
        except (FileNotFoundError, AssertionError, AttributeError, KeyError, ValueError, EOFError):
            # Missing, stale, old-schema or truncated cache: rebuild it from the label files.
            cache, exists = self.cache_labels(cache_path), False

        nf, nm, ne, nc, n = cache.pop("results")
        if exists:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))

        for k in ("hash", "version", "msgs"):
            cache.pop(k)
        labels = cache["labels"]
        # The cache omits corrupt images; keep im_files aligned with labels even when
        # the cache was loaded from disk rather than freshly built.
        self.im_files = [lb["im_file"] for lb in labels]

        len_boxes = sum(len(lb["bboxes"]) for lb in labels)
        len_segments = sum(len(lb["segments"]) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
            for lb in labels:
                lb["segments"] = []
        return labels

    def build_transforms(self, hyp=None):
        """Build a letterbox + ``Format`` pipeline.

        When ``hyp`` is None, ``Format``'s own mask defaults are used instead of
        ``hyp.mask_ratio`` / ``hyp.overlap_mask``.
        """
        transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        mask_kwargs = {} if hyp is None else dict(mask_ratio=hyp.mask_ratio, mask_overlap=hyp.overlap_mask)
        transforms.append(
            Format(bbox_format="xywh",
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
                   batch_idx=True,
                   **mask_kwargs))
        return transforms

    def close_mosaic(self, hyp):
        """Zero ``hyp.mosaic``, ``hyp.copy_paste`` and ``hyp.mixup`` in place, then rebuild the transforms."""
        hyp.mosaic = 0.0
        hyp.copy_paste = 0.0
        hyp.mixup = 0.0
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """Pack a label dict's boxes/segments/keypoints into ``label["instances"]``.

        Customize your label format here.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments")
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of sample dicts into one batch dict.

        ``img`` tensors are stacked; ``masks``/``keypoints``/``bboxes``/``cls``
        are concatenated along dim 0; other keys stay as tuples. Every sample
        must share the first sample's key order, because values are zipped by
        position. Each sample's ``batch_idx`` is offset by the sample's
        position in the batch.
        """
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            if k in ["masks", "keypoints", "bboxes", "cls"]:
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Tag every instance with the index of the image it came from.
        new_batch["batch_idx"] = torch.cat([idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
        return new_batch
|
|
|
|
|
class DFL(nn.Module):
    """Turn per-side discrete distributions over ``c1`` bins into expected values.

    The input's channel axis is split into 4 groups of ``c1`` logits. Each
    group is softmaxed over its bins and reduced by a frozen 1x1 convolution
    whose weights are ``0, 1, ..., c1 - 1``; the result is the expected bin
    index for each of the 4 sides.
    """

    def __init__(self, c1=16):
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)
        self.conv.weight.data[:] = bins
        self.c1 = c1

    def forward(self, x):
        """Map (batch, 4 * c1, anchors) logits to (batch, 4, anchors) expectations."""
        batch, _, anchors = x.shape
        probs = x.view(batch, 4, self.c1, anchors).transpose(2, 1).softmax(1)
        return self.conv(probs).view(batch, 4, anchors)
|
|
|
|
|
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Convert (left, top, right, bottom) distances from anchor points into boxes.

    Args:
        distance (torch.Tensor): ltrb distances, split into halves of size 2 along ``dim``.
        anchor_points (torch.Tensor): (x, y) anchor coordinates, broadcastable
            against each half of ``distance``.
        xywh (bool): if True return (cx, cy, w, h), otherwise (x1, y1, x2, y2).
        dim (int): axis that holds the coordinates.

    Returns:
        (torch.Tensor): boxes with 4 coordinates along ``dim``.
    """
    lt, rb = distance.split(2, dim)
    top_left = anchor_points - lt
    bottom_right = anchor_points + rb
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)
    center = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((center, size), dim)
|
|
|
|
|
def post_process(x, nc=80, reg_max=16, anchors_path="./anchors.npy", strides_path="./strides.npy"):
    """Decode raw multi-level head outputs into boxes and class scores.

    Args:
        x (List[torch.Tensor]): per-level head outputs. Each is viewed as
            (batch, reg_max * 4 + nc, -1), and the levels are concatenated along the last axis.
        nc (int): number of classes. Defaults to 80.
        reg_max (int): DFL bins per box side. Defaults to 16.
        anchors_path (str | Path): ``.npy`` file with anchor points (loaded on every call).
        strides_path (str | Path): ``.npy`` file with per-anchor strides (loaded on every call).

    Returns:
        (tuple): ``(y, x)``. ``y`` holds xywh boxes multiplied by the strides,
        followed by sigmoid class scores, along dim 1. ``x`` is the unchanged input.
    """
    dfl = DFL(reg_max)
    anchors = torch.tensor(np.load(anchors_path, allow_pickle=True))
    strides = torch.tensor(np.load(strides_path, allow_pickle=True))
    no = reg_max * 4 + nc  # channels per anchor: box distributions + class logits
    batch = x[0].shape[0]
    box, cls = torch.cat([xi.view(batch, no, -1) for xi in x], 2).split((reg_max * 4, nc), 1)
    # NOTE(review): assumes anchors are (2, num_anchors) and strides broadcast over (batch, 4, num_anchors) — confirm.
    dbox = dist2bbox(dfl(box), anchors.unsqueeze(0), xywh=True, dim=1) * strides
    y = torch.cat((dbox, cls.sigmoid()), 1)
    return y, x
|
|
|
|
|
def smooth(y, f=0.05):
    """Box-filter ``y`` with a window covering about a fraction ``f`` of its length.

    Both ends are padded with the edge values, so the output has the same
    length as ``y``.

    Args:
        y (np.ndarray): 1-D signal.
        f (float): window size as a fraction of ``len(y)``.

    Returns:
        (np.ndarray): smoothed signal, same length as ``y`` (empty for empty input).
    """
    if len(y) == 0:
        return np.asarray(y, dtype=float)
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements
    if nf % 2 == 0:
        # Symmetric padding of nf // 2 only preserves the length for odd windows;
        # an even window would return len(y) + 1 values.
        nf += 1
    p = np.ones(nf // 2)  # ones padding
    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed
|
|
|
|
|
def compute_ap(recall, precision, method='interp'):
    """Compute the average precision from recall and precision curves.

    Args:
        recall (array-like): recall curve.
        precision (array-like): precision curve.
        method (str): 'interp' (default) integrates the precision envelope
            sampled at 101 evenly spaced recall points. Any other value sums
            precision over each change in recall.

    Returns:
        (tuple): ``(ap, mpre, mrec)`` - the average precision, the precision
        envelope with sentinels, and the recall curve with sentinels.
    """
    # Append sentinel values at both ends
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Precision envelope: running maximum from the right, so it never increases with recall
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    if method == 'interp':
        x = np.linspace(0, 1, 101)
        # NumPy 2.0 renamed np.trapz to np.trapezoid and deprecated the old name.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        ap = trapezoid(np.interp(x, mrec, mpre), x)
    else:
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where recall changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])

    return ap, mpre, mrec
|
|
|
|
|
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
    """Compute per-class precision, recall, F1 and average precision.

    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.

    Args:
        tp (np.ndarray): true-positive flags, shape (n, num_iou_thresholds) (e.g. nx1 or nx10).
        conf (np.ndarray): confidence of each prediction, in [0, 1].
        pred_cls (np.ndarray): predicted class of each prediction.
        target_cls (np.ndarray): class of each ground-truth object.
        plot, save_dir, names, prefix: accepted for API compatibility; plotting
            is not implemented here, so they are unused.
        eps (float): small value that avoids division by zero.

    Returns:
        (tuple): ``(tp, fp, p, r, f1, ap, unique_classes)``. ``tp``/``fp``,
        ``p``, ``r`` and ``f1`` are per-class values at the confidence that
        maximizes the smoothed mean F1. ``ap`` has shape (nc, num_iou_thresholds).
        ``unique_classes`` holds the class ids (int) found in ``target_cls``.
    """
    # Sort by confidence, highest first
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes

    # Build precision/recall curves and compute AP for each class
    px = np.linspace(0, 1, 1000)
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall; conf is sorted descending, so -conf is increasing as np.interp requires
        recall = tpc / (n_l + eps)
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)

        # Precision
        precision = tpc / (tpc + fpc)
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)

        # AP from the recall-precision curve at each IoU threshold
        for j in range(tp.shape[1]):
            ap[ci, j], _, _ = compute_ap(recall[:, j], precision[:, j])

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)

    i = smooth(f1.mean(0), 0.1).argmax()  # index of max smoothed mean F1
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
|
|
|
|
|
class Metric:
    """Per-class precision/recall/F1/AP results for one task, with summary properties.

    All result attributes start as empty lists and are filled by ``update``:
        p, r, f1 (np.ndarray): per-class precision, recall and F1, shape (nc,).
        all_ap (np.ndarray): per-class AP at each IoU threshold, shape
            (nc, num_thresholds). Column 0 is read as AP@0.5 and column 5 as
            AP@0.75 (presumably thresholds 0.50:0.05:0.95 — confirm with the caller).
        ap_class_index (np.ndarray): class id of each row.
        nc (int): total number of dataset classes, set by the owner and used by ``maps``.
    """

    def __init__(self) -> None:
        self.p = []  # (nc, )
        self.r = []  # (nc, )
        self.f1 = []  # (nc, )
        self.all_ap = []  # (nc, num_thresholds)
        self.ap_class_index = []  # (nc, )
        self.nc = 0

    @property
    def ap50(self):
        """AP@0.5 of all classes.

        Return:
            (nc, ) or [].
        """
        return self.all_ap[:, 0] if len(self.all_ap) else []

    @property
    def ap(self):
        """AP@0.5:0.95 of each class (mean over the threshold columns).

        Return:
            (nc, ) or [].
        """
        return self.all_ap.mean(1) if len(self.all_ap) else []

    @property
    def mp(self):
        """Mean precision over classes (0.0 before ``update``).

        Return:
            float.
        """
        return self.p.mean() if len(self.p) else 0.0

    @property
    def mr(self):
        """Mean recall over classes (0.0 before ``update``).

        Return:
            float.
        """
        return self.r.mean() if len(self.r) else 0.0

    @property
    def map50(self):
        """Mean AP@0.5 over classes.

        Return:
            float.
        """
        return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0

    @property
    def map75(self):
        """Mean AP@0.75 over classes (threshold column 5).

        Return:
            float.
        """
        return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0

    @property
    def map(self):
        """Mean AP@0.5:0.95 over classes and thresholds.

        Return:
            float.
        """
        return self.all_ap.mean() if len(self.all_ap) else 0.0

    def mean_results(self):
        """Return ``[mp, mr, map50, map]``."""
        return [self.mp, self.mr, self.map50, self.map]

    def class_result(self, i):
        """Return ``(p[i], r[i], ap50[i], ap[i])`` for the i-th evaluated class."""
        return self.p[i], self.r[i], self.ap50[i], self.ap[i]

    @property
    def maps(self):
        """AP@0.5:0.95 for every one of ``nc`` classes; classes without results get the overall ``map``."""
        maps = np.zeros(self.nc) + self.map
        for i, c in enumerate(self.ap_class_index):
            maps[c] = self.ap[i]
        return maps

    def fitness(self):
        """Model fitness: 0.1 * mAP50 + 0.9 * mAP50-95 (precision and recall weighted 0)."""
        w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
        return (np.array(self.mean_results()) * w).sum()

    def update(self, results):
        """
        Args:
            results (tuple): ``(p, r, f1, all_ap, ap_class_index)`` in that order,
                i.e. the last five values returned by ``ap_per_class``.
        """
        self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
|
|
|
|
|
class DetMetrics:
    """Detection metrics: runs ``ap_per_class`` and keeps the results in a ``Metric``.

    Args:
        save_dir (Path): forwarded to ``ap_per_class``.
        plot (bool): forwarded to ``ap_per_class``.
        names (dict | tuple): class names; ``len(names)`` becomes the metric's class count.
    """

    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
        self.save_dir = save_dir
        self.plot = plot
        self.names = names
        self.box = Metric()

    def process(self, tp, conf, pred_cls, target_cls):
        """Compute AP statistics for one evaluation run and store them in ``self.box``."""
        stats = ap_per_class(tp, conf, pred_cls, target_cls,
                             plot=self.plot, save_dir=self.save_dir, names=self.names)
        per_class = stats[2:]  # drop the (tp, fp) counts
        self.box.nc = len(self.names)
        self.box.update(per_class)

    @property
    def keys(self):
        """Metric names, in the order used by ``mean_results``."""
        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]

    def mean_results(self):
        """Return ``[mp, mr, map50, map]`` of the box metric."""
        return self.box.mean_results()

    def class_result(self, i):
        """Return the i-th class's ``(p, r, ap50, ap)``."""
        return self.box.class_result(i)

    @property
    def maps(self):
        """Per-class mAP50-95 array of length ``nc``."""
        return self.box.maps

    @property
    def fitness(self):
        """Scalar fitness of the box metric."""
        return self.box.fitness()

    @property
    def ap_class_index(self):
        """Class ids that have AP results."""
        return self.box.ap_class_index

    @property
    def results_dict(self):
        """Mapping of each key in ``keys`` (plus ``"fitness"``) to its value."""
        labels = self.keys + ["fitness"]
        values = self.mean_results() + [self.fitness]
        return dict(zip(labels, values))
|
|
|
|
|
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """
    Return ``path``, or the first free ``path{sep}N`` variant (N = 2, 3, ...) if it already exists.

    For an existing file the numeric suffix goes before the extension
    (``a.txt`` -> ``a2.txt``); for a directory it is appended to the name.
    If N runs through 9998 without finding a free name, the 9998 variant is returned.

    Args:
        path (str or pathlib.Path): Path to increment.
        exist_ok (bool, optional): If True, return the path unchanged even if it exists. Defaults to False.
        sep (str, optional): Separator between the path and the number. Defaults to an empty string.
        mkdir (bool, optional): If True, create the resulting path as a directory (with parents). Defaults to False.

    Returns:
        pathlib.Path: Incremented path.
    """
    path = Path(path)
    if path.exists() and not exist_ok:
        if path.is_file():
            base, suffix = path.with_suffix(''), path.suffix
        else:
            base, suffix = path, ''
        candidate = None
        for n in range(2, 9999):
            candidate = f'{base}{sep}{n}{suffix}'
            if not os.path.exists(candidate):
                break
        path = Path(candidate)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)
    return path
|
|
|
|
|
def cfg2dict(cfg):
    """
    Normalize a configuration source into a dictionary.

    A ``SimpleNamespace`` yields its own attribute dict (``vars``, not a copy).
    A ``str``/``Path`` is loaded with ``yaml_load``. Anything else is returned unchanged.

    Inputs:
        cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.

    Returns:
        cfg (dict): Configuration object in dictionary format.
    """
    if isinstance(cfg, SimpleNamespace):
        return vars(cfg)
    if isinstance(cfg, (str, Path)):
        return yaml_load(cfg)
    return cfg
|
|
|
|
|
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = None, overrides: Dict = None):
    """
    Load and merge configuration data from a file or dictionary.

    Args:
        cfg (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
            None, or a file that loads as empty, is treated as an empty configuration.
            The caller's object is never modified.
        overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.

    Returns:
        (SimpleNamespace): Training arguments namespace.

    Raises:
        TypeError: if a known float/fraction/int/bool key holds a value of the wrong type.
        ValueError: if a fraction key lies outside [0.0, 1.0].
    """
    # Copy: cfg2dict returns a SimpleNamespace's own __dict__ and passes dicts through,
    # so the str() coercion below would otherwise write into the caller's object.
    cfg = dict(cfg2dict(cfg) or {})

    # Merge overrides
    if overrides:
        overrides = cfg2dict(overrides)
        cfg = {**cfg, **overrides}  # overrides take precedence

    # Numeric project/name values are used as path components, so make them strings
    for k in 'project', 'name':
        if k in cfg and isinstance(cfg[k], (int, float)):
            cfg[k] = str(cfg[k])

    # Type and value checks
    for k, v in cfg.items():
        if v is not None:  # None values may be from optional args
            if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
            elif k in CFG_FRACTION_KEYS:
                if not isinstance(v, (int, float)):
                    raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
                if not (0.0 <= v <= 1.0):
                    raise ValueError(f"'{k}={v}' is an invalid value. "
                                     f"Valid '{k}' values are between 0.0 and 1.0.")
            elif k in CFG_INT_KEYS and not isinstance(v, int):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"'{k}' must be an int (i.e. '{k}=0')")
            elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
                raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
                                f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")

    return IterableSimpleNamespace(**cfg)
|
|
|
|
|
def clip_boxes(boxes, shape):
    """
    Clip xyxy boxes in place so they lie inside an image of the given shape.

    Args:
        boxes (torch.Tensor | np.ndarray): boxes whose last axis is (x1, y1, x2, y2).
        shape (tuple): image shape, read as (height, width, ...).
    """
    height, width = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):
        # numpy: clip the x and y column pairs at once
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, width)
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, height)
        return
    # torch: clamp each column individually, in place
    for col, limit in ((0, width), (1, height), (2, width), (3, height)):
        boxes[..., col].clamp_(0, limit)
|
|
|
|
|
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """
    Map xyxy boxes from a letterboxed image back onto the original image, in place.

    Args:
        img1_shape (tuple): (height, width) of the image the boxes are currently expressed in.
        boxes (torch.Tensor): boxes whose first four values along the last axis are (x1, y1, x2, y2).
        img0_shape (tuple): (height, width) of the target image; boxes are clipped to it.
        ratio_pad (tuple): optional ``(ratio, pad)``. ``ratio[0]`` is used as the gain and
            ``pad`` as (pad_x, pad_y). If omitted, both are derived from the two shapes,
            assuming a centered, aspect-preserving resize.

    Returns:
        boxes (torch.Tensor): the same object, rescaled and clipped.
    """
    if ratio_pad is None:
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
        pad_x, pad_y = pad[0], pad[1]

    boxes[..., [0, 2]] -= pad_x
    boxes[..., [1, 3]] -= pad_y
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes
|
|
|
|
|
def exif_size(img):
    """Return the (width, height) of a PIL image as it appears after EXIF rotation.

    EXIF orientations 5-8 all include a 90/270-degree rotation, so width and
    height are swapped for them. If the EXIF data cannot be read (no EXIF, no
    orientation tag, or an image type without ``_getexif``), the raw size is
    returned unchanged.

    Args:
        img (PIL.Image.Image): an opened image.

    Returns:
        (tuple): (width, height).
    """
    s = img.size  # (width, height)
    with contextlib.suppress(Exception):
        # 274 (0x0112) is the EXIF Orientation tag. Using the constant avoids relying on a
        # module-level lookup whose absence would be silently swallowed by suppress().
        rotation = (img._getexif() or {}).get(274)
        if rotation in (5, 6, 7, 8):
            s = (s[1], s[0])
    return s
|
|
|
|
|
def verify_image_label(args):
    """Verify one image/label pair and parse its label file.

    Args:
        args (tuple): ``(im_file, lb_file, prefix, keypoint, num_cls)``:
            - im_file: image path.
            - lb_file: label path.
            - prefix: prefix for warning messages.
            - keypoint: whether the labels contain keypoints.
            - num_cls: the dataset's class count (valid class ids are 0..num_cls-1).

    Returns:
        On success, a tuple ``(im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg)``:
            - lb: an (n, 5) array of ``class, x, y, w, h`` rows.
            - shape: (height, width).
            - segments: a list of (k, 2) polygons (empty unless the labels are polygons).
            - keypoints: an (n, 17, 2) array, or None.
            - nm/nf/ne/nc: 0/1 flags for a missing/found/empty label file and a corrupt pair.
        On any failure, a list whose first five items are None, with ``nc`` = 1
        and ``msg`` describing the error.
    """
    im_file, lb_file, prefix, keypoint, num_cls = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify image
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # (width, height)
        shape = (shape[1], shape[0])  # (height, width)
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            # NOTE(review): this only seeks to the last two bytes and never reads them, so it
            # cannot detect a truncated JPEG; it fails only if the file cannot be opened/seeked.
            with open(im_file, "rb") as f:
                f.seek(-2, 2)

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # polygon labels
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == 56, "labels require 56 columns each"
                    assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                    assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                    kpts = np.zeros((lb.shape[0], 39))
                    for i in range(len(lb)):
                        # Drop every third keypoint value (the occlusion/visibility flag), keeping x, y.
                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3))
                        kpts[i] = np.hstack((lb[i, :5], kpt))
                    lb = kpts
                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    assert (lb[:, 1:] <= 1).all(), \
                        f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"

                # All labels
                max_cls = int(lb[:, 0].max())  # largest class id
                # Class ids are 0-based, so the largest valid id is num_cls - 1.
                assert max_cls < num_cls, \
                    f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
                    f'Possible class labels are 0-{num_cls - 1}'
                assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate rows
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, 17, 2)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]