glenn-jocher commited on
Commit
d3d9cbc
1 Parent(s): 6dd82c0

PyTorch 1.11.0 compatibility updates (#6932)

Browse files

Resolves `AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'` first raised in https://github.com/ultralytics/yolov5/issues/5499

Files changed (1) hide show
  1. models/experimental.py +10 -9
models/experimental.py CHANGED
@@ -94,21 +94,22 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
94
  model = Ensemble()
95
  for w in weights if isinstance(weights, list) else [weights]:
96
  ckpt = torch.load(attempt_download(w), map_location=map_location) # load
97
- if fuse:
98
- model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
99
- else:
100
- model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
101
 
102
  # Compatibility updates
103
  for m in model.modules():
104
- if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
105
- m.inplace = inplace # pytorch 1.7.0 compatibility
106
- if type(m) is Detect:
 
107
  if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
108
  delattr(m, 'anchor_grid')
109
  setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
110
- elif type(m) is Conv:
111
- m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
 
 
112
 
113
  if len(model) == 1:
114
  return model[-1] # return model
 
94
  model = Ensemble()
95
  for w in weights if isinstance(weights, list) else [weights]:
96
  ckpt = torch.load(attempt_download(w), map_location=map_location) # load
97
+ ckpt = (ckpt['ema'] or ckpt['model']).float() # FP32 model
98
+ model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
 
 
99
 
100
  # Compatibility updates
101
  for m in model.modules():
102
+ t = type(m)
103
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
104
+ m.inplace = inplace # torch 1.7.0 compatibility
105
+ if t is Detect:
106
  if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
107
  delattr(m, 'anchor_grid')
108
  setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
109
+ elif t is nn.Upsample:
110
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
111
+ elif t is Conv:
112
+ m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility
113
 
114
  if len(model) == 1:
115
  return model[-1] # return model