glenn-jocher
commited on
Commit
•
76ca367
1
Parent(s):
6fb5ff0
FP16 to FP32 ckpt load
Browse files
detect.py
CHANGED
@@ -18,7 +18,7 @@ def detect(save_img=False):
|
|
18 |
|
19 |
# Load model
|
20 |
google_utils.attempt_download(weights)
|
21 |
-
model = torch.load(weights, map_location=device)['model']
|
22 |
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
|
23 |
# model.fuse()
|
24 |
model.to(device).eval()
|
|
|
18 |
|
19 |
# Load model
|
20 |
google_utils.attempt_download(weights)
|
21 |
+
model = torch.load(weights, map_location=device)['model'].float() # load to FP32
|
22 |
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
|
23 |
# model.fuse()
|
24 |
model.to(device).eval()
|
test.py
CHANGED
@@ -32,7 +32,7 @@ def test(data,
|
|
32 |
|
33 |
# Load model
|
34 |
google_utils.attempt_download(weights)
|
35 |
-
model = torch.load(weights, map_location=device)['model']
|
36 |
torch_utils.model_info(model)
|
37 |
# model.fuse()
|
38 |
model.to(device)
|
|
|
32 |
|
33 |
# Load model
|
34 |
google_utils.attempt_download(weights)
|
35 |
+
model = torch.load(weights, map_location=device)['model'].float() # load to FP32
|
36 |
torch_utils.model_info(model)
|
37 |
# model.fuse()
|
38 |
model.to(device)
|