glenn-jocher commited on
Commit
7997241
1 Parent(s): 7947c86

Update C3 module (#1705)

Browse files
Files changed (3) hide show
  1. models/common.py +16 -1
  2. models/experimental.py +0 -19
  3. models/yolo.py +2 -2
models/common.py CHANGED
@@ -29,7 +29,7 @@ class Conv(nn.Module):
29
  super(Conv, self).__init__()
30
  self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
31
  self.bn = nn.BatchNorm2d(c2)
32
- self.act = nn.Hardswish() if act else nn.Identity()
33
 
34
  def forward(self, x):
35
  return self.act(self.bn(self.conv(x)))
@@ -70,6 +70,21 @@ class BottleneckCSP(nn.Module):
70
  return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  class SPP(nn.Module):
74
  # Spatial pyramid pooling layer used in YOLOv3-SPP
75
  def __init__(self, c1, c2, k=(5, 9, 13)):
 
29
  super(Conv, self).__init__()
30
  self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
31
  self.bn = nn.BatchNorm2d(c2)
32
+ self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
33
 
34
  def forward(self, x):
35
  return self.act(self.bn(self.conv(x)))
 
70
  return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
71
 
72
 
73
+ class C3(nn.Module):
74
+ # CSP Bottleneck with 3 convolutions
75
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
76
+ super(C3, self).__init__()
77
+ c_ = int(c2 * e) # hidden channels
78
+ self.cv1 = Conv(c1, c_, 1, 1)
79
+ self.cv2 = Conv(c1, c_, 1, 1)
80
+ self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
81
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
82
+ # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
83
+
84
+ def forward(self, x):
85
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
86
+
87
+
88
  class SPP(nn.Module):
89
  # Spatial pyramid pooling layer used in YOLOv3-SPP
90
  def __init__(self, c1, c2, k=(5, 9, 13)):
models/experimental.py CHANGED
@@ -22,25 +22,6 @@ class CrossConv(nn.Module):
22
  return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
23
 
24
 
25
- class C3(nn.Module):
26
- # Cross Convolution CSP
27
- def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
28
- super(C3, self).__init__()
29
- c_ = int(c2 * e) # hidden channels
30
- self.cv1 = Conv(c1, c_, 1, 1)
31
- self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
32
- self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
33
- self.cv4 = Conv(2 * c_, c2, 1, 1)
34
- self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
35
- self.act = nn.LeakyReLU(0.1, inplace=True)
36
- self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
37
-
38
- def forward(self, x):
39
- y1 = self.cv3(self.m(self.cv1(x)))
40
- y2 = self.cv2(x)
41
- return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
42
-
43
-
44
  class Sum(nn.Module):
45
  # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
46
  def __init__(self, n, weight=False): # n: number of inputs
 
22
  return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class Sum(nn.Module):
26
  # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
27
  def __init__(self, n, weight=False): # n: number of inputs
models/yolo.py CHANGED
@@ -11,8 +11,8 @@ import torch.nn as nn
11
  sys.path.append('./') # to run '$ python *.py' files in subdirectories
12
  logger = logging.getLogger(__name__)
13
 
14
- from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
15
- from models.experimental import MixConv2d, CrossConv, C3
16
  from utils.autoanchor import check_anchor_order
17
  from utils.general import make_divisible, check_file, set_logging
18
  from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
 
11
  sys.path.append('./') # to run '$ python *.py' files in subdirectories
12
  logger = logging.getLogger(__name__)
13
 
14
+ from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, Concat, NMS, autoShape
15
+ from models.experimental import MixConv2d, CrossConv
16
  from utils.autoanchor import check_anchor_order
17
  from utils.general import make_divisible, check_file, set_logging
18
  from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \