glenn-jocher
commited on
Commit
•
fcd180d
1
Parent(s):
7c6bae0
Refactor new `model.warmup()` method (#5810)
Browse files* Refactor new `model.warmup()` method
* Add half
- detect.py +1 -2
- models/common.py +7 -0
- val.py +1 -2
detect.py
CHANGED
@@ -97,8 +97,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
|
|
97 |
vid_path, vid_writer = [None] * bs, [None] * bs
|
98 |
|
99 |
# Run inference
|
100 |
-
|
101 |
-
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
|
102 |
dt, seen = [0.0, 0.0, 0.0], 0
|
103 |
for path, im, im0s, vid_cap, s in dataset:
|
104 |
t1 = time_sync()
|
|
|
97 |
vid_path, vid_writer = [None] * bs, [None] * bs
|
98 |
|
99 |
# Run inference
|
100 |
+
model.warmup(imgsz=(1, 3, *imgsz), half=half) # warmup
|
|
|
101 |
dt, seen = [0.0, 0.0, 0.0], 0
|
102 |
for path, im, im0s, vid_cap, s in dataset:
|
103 |
t1 = time_sync()
|
models/common.py
CHANGED
@@ -421,6 +421,13 @@ class DetectMultiBackend(nn.Module):
|
|
421 |
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
|
422 |
return (y, []) if val else y
|
423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
|
425 |
class AutoShape(nn.Module):
|
426 |
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
|
|
421 |
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
|
422 |
return (y, []) if val else y
|
423 |
|
424 |
+
def warmup(self, imgsz=(1, 3, 640, 640), half=False):
|
425 |
+
# Warmup model by running inference once
|
426 |
+
if self.pt or self.engine or self.onnx: # warmup types
|
427 |
+
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
|
428 |
+
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
|
429 |
+
self.forward(im) # warmup
|
430 |
+
|
431 |
|
432 |
class AutoShape(nn.Module):
|
433 |
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
val.py
CHANGED
@@ -149,8 +149,7 @@ def run(data,
|
|
149 |
|
150 |
# Dataloader
|
151 |
if not training:
|
152 |
-
|
153 |
-
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
|
154 |
pad = 0.0 if task == 'speed' else 0.5
|
155 |
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
156 |
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
|
|
|
149 |
|
150 |
# Dataloader
|
151 |
if not training:
|
152 |
+
model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half) # warmup
|
|
|
153 |
pad = 0.0 if task == 'speed' else 0.5
|
154 |
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
|
155 |
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
|