glenn-jocher commited on
Commit
5ea771d
1 Parent(s): 3213d87

Move IoU functions to metrics.py (#3820)

Browse files
Files changed (3) hide show
  1. utils/general.py +1 -79
  2. utils/loss.py +1 -1
  3. utils/metrics.py +80 -3
utils/general.py CHANGED
@@ -25,7 +25,7 @@ import torchvision
25
  import yaml
26
 
27
  from utils.google_utils import gsutil_getsize
28
- from utils.metrics import fitness
29
  from utils.torch_utils import init_torch_seeds
30
 
31
  # Settings
@@ -469,84 +469,6 @@ def clip_coords(boxes, img_shape):
469
  boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3]) # y2
470
 
471
 
472
- def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
473
- # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
474
- box2 = box2.T
475
-
476
- # Get the coordinates of bounding boxes
477
- if x1y1x2y2: # x1, y1, x2, y2 = box1
478
- b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
479
- b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
480
- else: # transform from xywh to xyxy
481
- b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
482
- b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
483
- b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
484
- b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
485
-
486
- # Intersection area
487
- inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
488
- (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
489
-
490
- # Union Area
491
- w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
492
- w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
493
- union = w1 * h1 + w2 * h2 - inter + eps
494
-
495
- iou = inter / union
496
- if GIoU or DIoU or CIoU:
497
- cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
498
- ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
499
- if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
500
- c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
501
- rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
502
- (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
503
- if DIoU:
504
- return iou - rho2 / c2 # DIoU
505
- elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
506
- v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
507
- with torch.no_grad():
508
- alpha = v / (v - iou + (1 + eps))
509
- return iou - (rho2 / c2 + v * alpha) # CIoU
510
- else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
511
- c_area = cw * ch + eps # convex area
512
- return iou - (c_area - union) / c_area # GIoU
513
- else:
514
- return iou # IoU
515
-
516
-
517
- def box_iou(box1, box2):
518
- # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
519
- """
520
- Return intersection-over-union (Jaccard index) of boxes.
521
- Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
522
- Arguments:
523
- box1 (Tensor[N, 4])
524
- box2 (Tensor[M, 4])
525
- Returns:
526
- iou (Tensor[N, M]): the NxM matrix containing the pairwise
527
- IoU values for every element in boxes1 and boxes2
528
- """
529
-
530
- def box_area(box):
531
- # box = 4xn
532
- return (box[2] - box[0]) * (box[3] - box[1])
533
-
534
- area1 = box_area(box1.T)
535
- area2 = box_area(box2.T)
536
-
537
- # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
538
- inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
539
- return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
540
-
541
-
542
- def wh_iou(wh1, wh2):
543
- # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
544
- wh1 = wh1[:, None] # [N,1,2]
545
- wh2 = wh2[None] # [1,M,2]
546
- inter = torch.min(wh1, wh2).prod(2) # [N,M]
547
- return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
548
-
549
-
550
  def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
551
  labels=(), max_det=300):
552
  """Runs Non-Maximum Suppression (NMS) on inference results
 
25
  import yaml
26
 
27
  from utils.google_utils import gsutil_getsize
28
+ from utils.metrics import box_iou, fitness
29
  from utils.torch_utils import init_torch_seeds
30
 
31
  # Settings
 
469
  boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3]) # y2
470
 
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
473
  labels=(), max_det=300):
474
  """Runs Non-Maximum Suppression (NMS) on inference results
utils/loss.py CHANGED
@@ -3,7 +3,7 @@
3
  import torch
4
  import torch.nn as nn
5
 
6
- from utils.general import bbox_iou
7
  from utils.torch_utils import is_parallel
8
 
9
 
 
3
  import torch
4
  import torch.nn as nn
5
 
6
+ from utils.metrics import bbox_iou
7
  from utils.torch_utils import is_parallel
8
 
9
 
utils/metrics.py CHANGED
@@ -1,5 +1,6 @@
1
  # Model validation metrics
2
 
 
3
  import warnings
4
  from pathlib import Path
5
 
@@ -7,8 +8,6 @@ import matplotlib.pyplot as plt
7
  import numpy as np
8
  import torch
9
 
10
- from . import general
11
-
12
 
13
  def fitness(x):
14
  # Model fitness as a weighted combination of metrics
@@ -128,7 +127,7 @@ class ConfusionMatrix:
128
  detections = detections[detections[:, 4] > self.conf]
129
  gt_classes = labels[:, 0].int()
130
  detection_classes = detections[:, 5].int()
131
- iou = general.box_iou(labels[:, 1:], detections[:, :4])
132
 
133
  x = torch.where(iou > self.iou_thres)
134
  if x[0].shape[0]:
@@ -184,6 +183,84 @@ class ConfusionMatrix:
184
  print(' '.join(map(str, self.matrix[i])))
185
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  # Plots ----------------------------------------------------------------------------------------------------------------
188
 
189
  def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
 
1
  # Model validation metrics
2
 
3
+ import math
4
  import warnings
5
  from pathlib import Path
6
 
 
8
  import numpy as np
9
  import torch
10
 
 
 
11
 
12
  def fitness(x):
13
  # Model fitness as a weighted combination of metrics
 
127
  detections = detections[detections[:, 4] > self.conf]
128
  gt_classes = labels[:, 0].int()
129
  detection_classes = detections[:, 5].int()
130
+ iou = box_iou(labels[:, 1:], detections[:, :4])
131
 
132
  x = torch.where(iou > self.iou_thres)
133
  if x[0].shape[0]:
 
183
  print(' '.join(map(str, self.matrix[i])))
184
 
185
 
186
+ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
187
+ # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
188
+ box2 = box2.T
189
+
190
+ # Get the coordinates of bounding boxes
191
+ if x1y1x2y2: # x1, y1, x2, y2 = box1
192
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
193
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
194
+ else: # transform from xywh to xyxy
195
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
196
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
197
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
198
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
199
+
200
+ # Intersection area
201
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
202
+ (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
203
+
204
+ # Union Area
205
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
206
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
207
+ union = w1 * h1 + w2 * h2 - inter + eps
208
+
209
+ iou = inter / union
210
+ if GIoU or DIoU or CIoU:
211
+ cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
212
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
213
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
214
+ c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
215
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
216
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
217
+ if DIoU:
218
+ return iou - rho2 / c2 # DIoU
219
+ elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
220
+ v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
221
+ with torch.no_grad():
222
+ alpha = v / (v - iou + (1 + eps))
223
+ return iou - (rho2 / c2 + v * alpha) # CIoU
224
+ else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
225
+ c_area = cw * ch + eps # convex area
226
+ return iou - (c_area - union) / c_area # GIoU
227
+ else:
228
+ return iou # IoU
229
+
230
+
231
+ def box_iou(box1, box2):
232
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
233
+ """
234
+ Return intersection-over-union (Jaccard index) of boxes.
235
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
236
+ Arguments:
237
+ box1 (Tensor[N, 4])
238
+ box2 (Tensor[M, 4])
239
+ Returns:
240
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
241
+ IoU values for every element in boxes1 and boxes2
242
+ """
243
+
244
+ def box_area(box):
245
+ # box = 4xn
246
+ return (box[2] - box[0]) * (box[3] - box[1])
247
+
248
+ area1 = box_area(box1.T)
249
+ area2 = box_area(box2.T)
250
+
251
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
252
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
253
+ return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
254
+
255
+
256
+ def wh_iou(wh1, wh2):
257
+ # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
258
+ wh1 = wh1[:, None] # [N,1,2]
259
+ wh2 = wh2[None] # [1,M,2]
260
+ inter = torch.min(wh1, wh2).prod(2) # [N,M]
261
+ return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
262
+
263
+
264
  # Plots ----------------------------------------------------------------------------------------------------------------
265
 
266
  def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):