glenn-jocher commited on
Commit
76ca367
1 Parent(s): 6fb5ff0

FP16 to FP32 ckpt load

Browse files
Files changed (2) hide show
  1. detect.py +1 -1
  2. test.py +1 -1
detect.py CHANGED
@@ -18,7 +18,7 @@ def detect(save_img=False):
18
 
19
  # Load model
20
  google_utils.attempt_download(weights)
21
- model = torch.load(weights, map_location=device)['model']
22
  # torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
23
  # model.fuse()
24
  model.to(device).eval()
 
18
 
19
  # Load model
20
  google_utils.attempt_download(weights)
21
+ model = torch.load(weights, map_location=device)['model'].float() # load to FP32
22
  # torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning
23
  # model.fuse()
24
  model.to(device).eval()
test.py CHANGED
@@ -32,7 +32,7 @@ def test(data,
32
 
33
  # Load model
34
  google_utils.attempt_download(weights)
35
- model = torch.load(weights, map_location=device)['model']
36
  torch_utils.model_info(model)
37
  # model.fuse()
38
  model.to(device)
 
32
 
33
  # Load model
34
  google_utils.attempt_download(weights)
35
+ model = torch.load(weights, map_location=device)['model'].float() # load to FP32
36
  torch_utils.model_info(model)
37
  # model.fuse()
38
  model.to(device)