Commit c949fc8
Parent(s): c5c647e

Detection cropping+saving feature addition for detect.py and PyTorch Hub (#2827)

* Update detect.py
* Update detect.py
* Update greetings.yml
* Update cropping
* cleanup
* Update increment_path()
* Update common.py
* Update detect.py
* Update detect.py
* Update detect.py
* Update common.py
* cleanup
* Update detect.py
Co-authored-by: Glenn Jocher <[email protected]>
- detect.py +11 -7
- models/common.py +20 -12
- test.py +1 -1
- train.py +3 -3
- utils/general.py +21 -6
detect.py CHANGED

@@ -10,19 +10,19 @@ from numpy import random
 from models.experimental import attempt_load
 from utils.datasets import LoadStreams, LoadImages
 from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
-    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
+    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
 from utils.plots import plot_one_box
 from utils.torch_utils import select_device, load_classifier, time_synchronized


-def detect(save_img=False):
+def detect():
     source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
     save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
     webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
         ('rtsp://', 'rtmp://', 'http://', 'https://'))

     # Directories
-    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
+    save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

     # Initialize

@@ -84,7 +84,7 @@ def detect(save_img=False):
         if webcam:  # batch_size >= 1
             p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
         else:
-            p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
+            p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)

         p = Path(p)  # to Path
         save_path = str(save_dir / p.name)  # img.jpg

@@ -108,9 +108,12 @@ def detect(save_img=False):
                         with open(txt_path + '.txt', 'a') as f:
                             f.write(('%g ' * len(line)).rstrip() % line + '\n')

-                    if save_img or view_img:  # Add bbox to image
-                        label = f'{names[int(cls)]} {conf:.2f}'
-                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
+                    if save_img or opt.save_crop or view_img:  # Add bbox to image
+                        c = int(cls)  # integer class
+                        label = f'{names[c]} {conf:.2f}'
+                        plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=3)
+                        if opt.save_crop:
+                            save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

             # Print time (inference + NMS)
             print(f'{s}Done. ({t2 - t1:.3f}s)')

@@ -157,6 +160,7 @@ if __name__ == '__main__':
     parser.add_argument('--view-img', action='store_true', help='display results')
     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
     parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
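For orientation, a minimal sketch of what the new --save-crop path does per detection, assuming the repository layout above. The image path, class name and box coordinates are illustrative values, not taken from the commit; detect.py passes the unannotated original frame (im0s) and BGR=True so the crop is written exactly as loaded by OpenCV.

# Hedged sketch: reproduce what `python detect.py --save-crop` does for one detection.
# 'data/images/zidane.jpg', 'person' and the box below are illustrative assumptions.
from pathlib import Path

import cv2

from utils.general import save_one_box

im0s = cv2.imread('data/images/zidane.jpg')  # original BGR frame, as loaded by LoadImages
xyxy = [100, 50, 400, 300]                   # one detection box in pixel xyxy format
save_dir = Path('runs/detect/exp')           # detect.py builds this via increment_path()
save_one_box(xyxy, im0s, file=save_dir / 'crops' / 'person' / 'zidane.jpg', BGR=True)
# writes a slightly padded crop to runs/detect/exp/crops/person/zidane.jpg

At the CLI this corresponds to something like python detect.py --source data/images --weights yolov5s.pt --save-crop, where the source and weights are shown only as examples.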
models/common.py CHANGED

@@ -13,7 +13,7 @@ from PIL import Image
 from torch.cuda import amp

 from utils.datasets import letterbox
-from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
+from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
 from utils.plots import color_list, plot_one_box
 from utils.torch_utils import time_synchronized

@@ -311,29 +311,33 @@ class Detections:
         self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
         self.s = shape  # inference BCHW shape

-    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
+    def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
         colors = color_list()
-        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
-            str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
+        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
+            str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
             if pred is not None:
                 for c in pred[:, -1].unique():
                     n = (pred[:, -1] == c).sum()  # detections per class
                     str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
-                if show or save or render:
+                if show or save or render or crop:
                     for *box, conf, cls in pred:  # xyxy, confidence, class
                         label = f'{self.names[int(cls)]} {conf:.2f}'
-                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
-            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
+                        if crop:
+                            save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
+                        else:  # all others
+                            plot_one_box(box, im, label=label, color=colors[int(cls) % 10])
+
+            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
             if pprint:
                 print(str.rstrip(', '))
             if show:
-                img.show(self.files[i])  # show
+                im.show(self.files[i])  # show
             if save:
                 f = self.files[i]
-                img.save(Path(save_dir) / f)  # save
+                im.save(save_dir / f)  # save
                 print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
             if render:
-                self.imgs[i] = np.asarray(img)
+                self.imgs[i] = np.asarray(im)

     def print(self):
         self.display(pprint=True)  # print results

@@ -343,10 +347,14 @@ class Detections:
         self.display(show=True)  # show results

     def save(self, save_dir='runs/hub/exp'):
-        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp')  # increment save_dir
-        Path(save_dir).mkdir(parents=True, exist_ok=True)
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
         self.display(save=True, save_dir=save_dir)  # save results

+    def crop(self, save_dir='runs/hub/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True)  # increment save_dir
+        self.display(crop=True, save_dir=save_dir)  # crop results
+        print(f'Saved results to {save_dir}\n')
+
     def render(self):
         self.display(render=True)  # render results
         return self.imgs
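The PyTorch Hub side gets the same capability through the new Detections.crop() method. A hedged usage sketch follows; the hub model name and image URL are illustrative choices, and the output locations follow from the save()/crop() defaults above (runs/hub/exp, incremented per run).

# Hedged sketch of the PyTorch Hub workflow; model and image are example choices.
import torch

model = torch.hub.load('ultralytics/yolov5', 'yolov5s')      # pretrained YOLOv5s via PyTorch Hub
results = model('https://ultralytics.com/images/zidane.jpg')  # run inference on one image

results.print()  # per-image detection summary
results.save()   # annotated images -> runs/hub/exp*/
results.crop()   # cropped boxes    -> runs/hub/exp*/crops/<class name>/<image file>

Because save() and crop() each call increment_path() on the default save_dir, repeated calls land in freshly incremented exp directories rather than overwriting earlier results.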
test.py CHANGED

@@ -49,7 +49,7 @@ def test(data,
     device = select_device(opt.device, batch_size=batch_size)

     # Directories
-    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
+    save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

     # Load model
train.py CHANGED

@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
 def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
+        opt.save_dir, opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

     # Directories
     wdir = save_dir / 'weights'

@@ -69,7 +69,7 @@ def train(hyp, opt, device, tb_writer=None):
     if rank in [-1, 0]:
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
-        wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
+        wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
         loggers['wandb'] = wandb_logger.wandb
         data_dict = wandb_logger.data_dict
         if wandb_logger.wandb:

@@ -577,7 +577,7 @@ if __name__ == '__main__':
         assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
-        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
+        yaml_file = opt.save_dir / 'hyp_evolved.yaml'  # save best result here
         if opt.bucket:
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

utils/general.py CHANGED

@@ -557,7 +557,7 @@ def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):


 def apply_classifier(x, model, img, im0):
-    # applies a second stage classifier to yolo outputs
+    # Apply a second stage classifier to yolo outputs
     im0 = [im0] if isinstance(im0, np.ndarray) else im0
     for i, d in enumerate(x):  # per image
         if d is not None and len(d):

@@ -591,16 +591,31 @@ def apply_classifier(x, model, img, im0):
     return x


-def increment_path(path, exist_ok=True, sep=''):
+def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False):
+    # Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels
+    xyxy = torch.tensor(xyxy).view(-1, 4)
+    b = xyxy2xywh(xyxy)  # boxes
+    if square:
+        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
+    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
+    xyxy = xywh2xyxy(b).long()
+    clip_coords(xyxy, im.shape)
+    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])]
+    cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop if BGR else crop[..., ::-1])
+
+
+def increment_path(path, exist_ok=False, sep='', mkdir=False):
     # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
     path = Path(path)  # os-agnostic
-    if (path.exists() and exist_ok) or (not path.exists()):
-        return str(path)
-    else:
+    if path.exists() and not exist_ok:
         suffix = path.suffix
         path = path.with_suffix('')
         dirs = glob.glob(f"{path}{sep}*")  # similar paths
         matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
         i = [int(m.groups()[0]) for m in matches if m]  # indices
         n = max(i) + 1 if i else 2  # increment number
-        return f"{path}{sep}{n}{suffix}"  # update path
+        path = Path(f"{path}{sep}{n}{suffix}")  # update path
+    dir = path if path.suffix == '' else path.parent  # directory
+    if not dir.exists() and mkdir:
+        dir.mkdir(parents=True, exist_ok=True)  # make directory
+    return path
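To make the helper changes concrete, a small sketch of the revised utils/general.py behaviour, assuming the function bodies above; the paths are illustrative. increment_path() now returns a pathlib.Path (which is why the Path(...) wrappers were dropped in detect.py, test.py and train.py), and with mkdir=True it also creates the target directory, which is what save_one_box() relies on when writing crops.

# Hedged sketch of the updated helper; 'runs/detect/exp' and the crop file path are examples.
from pathlib import Path

from utils.general import increment_path

p = increment_path(Path('runs/detect') / 'exp', exist_ok=False)
print(type(p), p)  # a pathlib.Path now (not a str); runs/detect/exp, or runs/detect/exp2, exp3, ... if it already exists

f = increment_path('runs/detect/exp/crops/person/zidane.jpg', mkdir=True)
print(f)  # same file path with its parent directory created; an existing file increments to zidane2.jpg, zidane3.jpg, ...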