glenn-jocher commited on
Commit
1d17b9a
1 Parent(s): 4b5f480

update yolo.py TTA flexibility and extensibility (#506)

Browse files

* update yolo.py TTA flexibility and extensibility

* Update scale_img()

Files changed (2) hide show
  1. models/yolo.py +13 -12
  2. utils/torch_utils.py +10 -7
models/yolo.py CHANGED
@@ -82,18 +82,19 @@ class Model(nn.Module):
82
  def forward(self, x, augment=False, profile=False):
83
  if augment:
84
  img_size = x.shape[-2:] # height, width
85
- s = [0.83, 0.67] # scales
86
- y = []
87
- for i, xi in enumerate((x,
88
- torch_utils.scale_img(x.flip(3), s[0]), # flip-lr and scale
89
- torch_utils.scale_img(x, s[1]), # scale
90
- )):
91
- # cv2.imwrite('img%g.jpg' % i, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])
92
- y.append(self.forward_once(xi)[0])
93
-
94
- y[1][..., :4] /= s[0] # scale
95
- y[1][..., 0] = img_size[1] - y[1][..., 0] # flip lr
96
- y[2][..., :4] /= s[1] # scale
 
97
  return torch.cat(y, 1), None # augmented inference, train
98
  else:
99
  return self.forward_once(x, profile) # single-scale inference, train
 
82
  def forward(self, x, augment=False, profile=False):
83
  if augment:
84
  img_size = x.shape[-2:] # height, width
85
+ s = [1, 0.83, 0.67] # scales
86
+ f = [None, 3, None] # flips (2-ud, 3-lr)
87
+ y = [] # outputs
88
+ for si, fi in zip(s, f):
89
+ xi = torch_utils.scale_img(x.flip(fi) if fi else x, si)
90
+ yi = self.forward_once(xi)[0] # forward
91
+ # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
92
+ yi[..., :4] /= si # de-scale
93
+ if fi is 2:
94
+ yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
95
+ elif fi is 3:
96
+ yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
97
+ y.append(yi)
98
  return torch.cat(y, 1), None # augmented inference, train
99
  else:
100
  return self.forward_once(x, profile) # single-scale inference, train
utils/torch_utils.py CHANGED
@@ -164,13 +164,16 @@ def load_classifier(name='resnet101', n=2):
164
 
165
  def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
166
  # scales img(bs,3,y,x) by ratio
167
- h, w = img.shape[2:]
168
- s = (int(h * ratio), int(w * ratio)) # new size
169
- img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
170
- if not same_shape: # pad/crop img
171
- gs = 32 # (pixels) grid size
172
- h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
173
- return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
 
 
 
174
 
175
 
176
  def copy_attr(a, b, include=(), exclude=()):
 
164
 
165
  def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
166
  # scales img(bs,3,y,x) by ratio
167
+ if ratio == 1.0:
168
+ return img
169
+ else:
170
+ h, w = img.shape[2:]
171
+ s = (int(h * ratio), int(w * ratio)) # new size
172
+ img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
173
+ if not same_shape: # pad/crop img
174
+ gs = 32 # (pixels) grid size
175
+ h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
176
+ return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
177
 
178
 
179
  def copy_attr(a, b, include=(), exclude=()):