glenn-jocher
committed on
Commit
•
1d17b9a
1
Parent(s):
4b5f480
update yolo.py TTA flexibility and extensibility (#506)
Browse files
* update yolo.py TTA flexibility and extensibility
* Update scale_img()
- models/yolo.py +13 -12
- utils/torch_utils.py +10 -7
models/yolo.py
CHANGED
@@ -82,18 +82,19 @@ class Model(nn.Module):
|
|
82 |
def forward(self, x, augment=False, profile=False):
|
83 |
if augment:
|
84 |
img_size = x.shape[-2:] # height, width
|
85 |
-
s = [0.83, 0.67] # scales
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
# cv2.imwrite('img%g.jpg' %
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
97 |
return torch.cat(y, 1), None # augmented inference, train
|
98 |
else:
|
99 |
return self.forward_once(x, profile) # single-scale inference, train
|
|
|
82 |
def forward(self, x, augment=False, profile=False):
    """Run the model once, or several times with test-time augmentation.

    Args:
        x: Input batch. Its last two dimensions are read as (height, width).
        augment: When True, run one pass per (scale, flip) pair, undo the
            augmentation on each output, and concatenate along dim 1.
        profile: Forwarded to ``forward_once`` on the non-augmented path only.

    Returns:
        ``(torch.cat(outputs, 1), None)`` when ``augment`` is True; otherwise
        whatever ``self.forward_once(x, profile)`` returns.
    """
    if not augment:
        return self.forward_once(x, profile)  # single-scale inference, train

    img_size = x.shape[-2:]  # height, width
    s = [1, 0.83, 0.67]  # scales
    f = [None, 3, None]  # flips (2-ud, 3-lr), paired one-to-one with scales
    y = []  # outputs
    for si, fi in zip(s, f):
        xi = torch_utils.scale_img(x.flip(fi) if fi else x, si)
        yi = self.forward_once(xi)[0]  # forward
        # cv2.imwrite('img%g.jpg' % si, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
        # The first four entries of the last dim are rescaled in place
        # (presumably box x, y, w, h -- confirm against the detection head).
        yi[..., :4] /= si  # de-scale
        # Compare by value: `is` on int literals depends on interpreter
        # caching and emits a SyntaxWarning on Python >= 3.8.
        if fi == 2:
            yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
        elif fi == 3:
            yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
        y.append(yi)
    return torch.cat(y, 1), None  # augmented inference, train
|
utils/torch_utils.py
CHANGED
@@ -164,13 +164,16 @@ def load_classifier(name='resnet101', n=2):
|
|
164 |
|
165 |
def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
|
166 |
# scales img(bs,3,y,x) by ratio
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
174 |
|
175 |
|
176 |
def copy_attr(a, b, include=(), exclude=()):
|
|
|
164 |
|
165 |
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416), r=ratio
    """Scale img(bs,3,y,x) by ``ratio`` and pad/crop the result.

    Args:
        img: 4-D image batch; dimensions 2 and 3 are read as height and width.
        ratio: Scale factor. A ratio equal to 1.0 returns ``img`` itself.
        same_shape: When True, pad/crop the resized image back to the input
            height and width. When False, pad/crop to the scaled size rounded
            up to a multiple of ``gs``.
        gs: (pixels) grid size used for rounding when ``same_shape`` is False.
            Defaults to 32, the value that was previously hard-coded.

    Returns:
        The input tensor unchanged when ``ratio == 1.0``; otherwise a new
        bilinearly resized tensor, padded on the right and bottom with 0.447.
    """
    if ratio == 1.0:
        return img

    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
    # Negative pad amounts (target smaller than resized image) crop instead.
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
|
177 |
|
178 |
|
179 |
def copy_attr(a, b, include=(), exclude=()):
|