glenn-jocher
commited on
Commit
•
dd28df9
1
Parent(s):
9d7bc06
Avoid FP64 ops for MPS support in train.py (#8511)
Browse filesAvoid FP64 ops for MPS support
Resolves https://github.com/ultralytics/yolov5/pull/7878#issuecomment-1177952614
- utils/general.py +3 -3
utils/general.py
CHANGED
@@ -644,7 +644,7 @@ def labels_to_class_weights(labels, nc=80):
|
|
644 |
return torch.Tensor()
|
645 |
|
646 |
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
647 |
-
classes = labels[:, 0].astype(
|
648 |
weights = np.bincount(classes, minlength=nc) # occurrences per class
|
649 |
|
650 |
# Prepend gridpoint count (for uCE training)
|
@@ -654,13 +654,13 @@ def labels_to_class_weights(labels, nc=80):
|
|
654 |
weights[weights == 0] = 1 # replace empty bins with 1
|
655 |
weights = 1 / weights # number of targets per class
|
656 |
weights /= weights.sum() # normalize
|
657 |
-
return torch.from_numpy(weights)
|
658 |
|
659 |
|
660 |
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
661 |
# Produces image weights based on class_weights and image contents
|
662 |
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
|
663 |
-
class_counts = np.array([np.bincount(x[:, 0].astype(
|
664 |
return (class_weights.reshape(1, nc) * class_counts).sum(1)
|
665 |
|
666 |
|
|
|
644 |
return torch.Tensor()
|
645 |
|
646 |
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
647 |
+
classes = labels[:, 0].astype(int) # labels = [class xywh]
|
648 |
weights = np.bincount(classes, minlength=nc) # occurrences per class
|
649 |
|
650 |
# Prepend gridpoint count (for uCE training)
|
|
|
654 |
weights[weights == 0] = 1 # replace empty bins with 1
|
655 |
weights = 1 / weights # number of targets per class
|
656 |
weights /= weights.sum() # normalize
|
657 |
+
return torch.from_numpy(weights).float()
|
658 |
|
659 |
|
660 |
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
661 |
# Produces image weights based on class_weights and image contents
|
662 |
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
|
663 |
+
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
|
664 |
return (class_weights.reshape(1, nc) * class_counts).sum(1)
|
665 |
|
666 |
|