BeMerciless committed
Commit 17b7b6b
1 Parent(s): 558cce1

Upload 4 files

Files changed (4)
  1. models/__init__.py +1 -0
  2. models/common.py +2047 -0
  3. models/experimental.py +262 -0
  4. models/yolo.py +953 -0
models/__init__.py ADDED
@@ -0,0 +1 @@
+ # init
models/common.py ADDED
@@ -0,0 +1,2047 @@
+ import math
2
+ from copy import copy
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import requests
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torchvision.ops import DeformConv2d
12
+ from PIL import Image
13
+ from torch.cuda import amp
14
+
15
+ from utils.datasets import letterbox
16
+ from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
17
+ from utils.plots import color_list, plot_one_box
18
+ from utils.torch_utils import time_synchronized
19
+
20
+
21
+ ##### basic #####
22
+
23
+ def autopad(k, p=None): # kernel, padding
24
+ # Pad to 'same'
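+ # e.g. autopad(3) -> 1, autopad((3, 5)) -> [1, 2]; keeps H and W unchanged for stride 1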
25
+ if p is None:
26
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
27
+ return p
28
+
29
+
30
+ class MP(nn.Module):
31
+ def __init__(self, k=2):
32
+ super(MP, self).__init__()
33
+ self.m = nn.MaxPool2d(kernel_size=k, stride=k)
34
+
35
+ def forward(self, x):
36
+ return self.m(x)
37
+
38
+
39
+ class SP(nn.Module):
40
+ def __init__(self, k=3, s=1):
41
+ super(SP, self).__init__()
42
+ self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
43
+
44
+ def forward(self, x):
45
+ return self.m(x)
46
+
47
+
48
+ class ReOrg(nn.Module):
49
+ def __init__(self):
50
+ super(ReOrg, self).__init__()
51
+
52
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
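+ # space-to-depth: every other pixel is sampled into one of four channel groups (same slicing as Focus below)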
53
+ return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
54
+
55
+
56
+ class Merge(nn.Module):
57
+ def __init__(self,ch=()):
58
+ super(Merge, self).__init__()
59
+
60
+ def forward(self, x):
61
+
62
+ return [x[0],x[1],x[2]]
63
+
64
+
65
+ class Refine(nn.Module):
66
+
67
+ def __init__(self, c2, k, s, ch): # ch_out, kernel, stride, list of input channels
68
+ super(Refine, self).__init__()
69
+ self.refine = nn.ModuleList()
70
+ for c in ch:
71
+ self.refine.append(Conv(c, c2, k, s))
72
+
73
+ def forward(self, x):
74
+ for i, f in enumerate(x):
75
+ if i == 0:
76
+ r = self.refine[i](f)
77
+ else:
78
+ r_p = self.refine[i](f)
79
+ r_p = F.interpolate(r_p, r.size()[2:], mode="bilinear", align_corners=False)
80
+ r = r + r_p
81
+ return r
82
+
83
+
84
+ class Concat(nn.Module):
85
+ def __init__(self, dimension=1):
86
+ super(Concat, self).__init__()
87
+ self.d = dimension
88
+
89
+ def forward(self, x):
90
+ return torch.cat(x, self.d)
91
+
92
+
93
+ class Chuncat(nn.Module):
94
+ def __init__(self, dimension=1):
95
+ super(Chuncat, self).__init__()
96
+ self.d = dimension
97
+
98
+ def forward(self, x):
99
+ x1 = []
100
+ x2 = []
101
+ for xi in x:
102
+ xi1, xi2 = xi.chunk(2, self.d)
103
+ x1.append(xi1)
104
+ x2.append(xi2)
105
+ return torch.cat(x1+x2, self.d)
106
+
107
+
108
+ class Shortcut(nn.Module):
109
+ def __init__(self, dimension=0):
110
+ super(Shortcut, self).__init__()
111
+ self.d = dimension
112
+
113
+ def forward(self, x):
114
+ return x[0]+x[1]
115
+
116
+
117
+ class Foldcut(nn.Module):
118
+ def __init__(self, dimension=0):
119
+ super(Foldcut, self).__init__()
120
+ self.d = dimension
121
+
122
+ def forward(self, x):
123
+ x1, x2 = x.chunk(2, self.d)
124
+ return x1+x2
125
+
126
+
127
+ class Conv(nn.Module):
128
+ # Standard convolution
129
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
130
+ super(Conv, self).__init__()
131
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
132
+ self.bn = nn.BatchNorm2d(c2)
133
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
134
+
135
+ def forward(self, x):
136
+ return self.act(self.bn(self.conv(x)))
137
+
138
+ def fuseforward(self, x):
139
+ return self.act(self.conv(x))
140
+
141
+
142
+ class RobustConv(nn.Module):
143
+ # Robust convolution (use a large kernel size, 7-11, for downsampling and other layers). Train for 300-450 epochs.
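+ # large-kernel depthwise conv followed by a 1x1 pointwise conv, with an optional learnable per-channel scale (gamma)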
144
+ def __init__(self, c1, c2, k=7, s=1, p=None, g=1, act=True, layer_scale_init_value=1e-6): # ch_in, ch_out, kernel, stride, padding, groups
145
+ super(RobustConv, self).__init__()
146
+ self.conv_dw = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
147
+ self.conv1x1 = nn.Conv2d(c1, c2, 1, 1, 0, groups=1, bias=True)
148
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None
149
+
150
+ def forward(self, x):
151
+ x = x.to(memory_format=torch.channels_last)
152
+ x = self.conv1x1(self.conv_dw(x))
153
+ if self.gamma is not None:
154
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
155
+ return x
156
+
157
+
158
+ class RobustConv2(nn.Module):
159
+ # Robust convolution 2 (use [32, 5, 2] or [32, 7, 4] or [32, 11, 8] for one of the paths in CSP).
160
+ def __init__(self, c1, c2, k=7, s=4, p=None, g=1, act=True, layer_scale_init_value=1e-6): # ch_in, ch_out, kernel, stride, padding, groups
161
+ super(RobustConv2, self).__init__()
162
+ self.conv_strided = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
163
+ self.conv_deconv = nn.ConvTranspose2d(in_channels=c1, out_channels=c2, kernel_size=s, stride=s,
164
+ padding=0, bias=True, dilation=1, groups=1
165
+ )
166
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None
167
+
168
+ def forward(self, x):
169
+ x = self.conv_deconv(self.conv_strided(x))
170
+ if self.gamma is not None:
171
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
172
+ return x
173
+
174
+
175
+ def DWConv(c1, c2, k=1, s=1, act=True):
176
+ # Depthwise convolution
177
+ return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
178
+
179
+
180
+ class GhostConv(nn.Module):
181
+ # Ghost Convolution https://github.com/huawei-noah/ghostnet
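+ # half the output channels come from a standard conv (cv1); the rest are generated cheaply by a 5x5 depthwise conv (cv2) on that output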
182
+ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
183
+ super(GhostConv, self).__init__()
184
+ c_ = c2 // 2 # hidden channels
185
+ self.cv1 = Conv(c1, c_, k, s, None, g, act)
186
+ self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
187
+
188
+ def forward(self, x):
189
+ y = self.cv1(x)
190
+ return torch.cat([y, self.cv2(y)], 1)
191
+
192
+
193
+ class Stem(nn.Module):
194
+ # Stem
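+ # stride-2 conv, then two parallel stride-2 paths (conv and max-pool) that are concatenated and fused by a 1x1 conv: 4x downsampling overall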
195
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
196
+ super(Stem, self).__init__()
197
+ c_ = int(c2/2) # hidden channels
198
+ self.cv1 = Conv(c1, c_, 3, 2)
199
+ self.cv2 = Conv(c_, c_, 1, 1)
200
+ self.cv3 = Conv(c_, c_, 3, 2)
201
+ self.pool = torch.nn.MaxPool2d(2, stride=2)
202
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
203
+
204
+ def forward(self, x):
205
+ x = self.cv1(x)
206
+ return self.cv4(torch.cat((self.cv3(self.cv2(x)), self.pool(x)), dim=1))
207
+
208
+
209
+ class DownC(nn.Module):
210
+ # DownC downsampling block: a strided-conv path and a max-pool path, concatenated
211
+ def __init__(self, c1, c2, n=1, k=2):
212
+ super(DownC, self).__init__()
213
+ c_ = int(c1) # hidden channels
214
+ self.cv1 = Conv(c1, c_, 1, 1)
215
+ self.cv2 = Conv(c_, c2//2, 3, k)
216
+ self.cv3 = Conv(c1, c2//2, 1, 1)
217
+ self.mp = nn.MaxPool2d(kernel_size=k, stride=k)
218
+
219
+ def forward(self, x):
220
+ return torch.cat((self.cv2(self.cv1(x)), self.cv3(self.mp(x))), dim=1)
221
+
222
+
223
+ class SPP(nn.Module):
224
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
225
+ def __init__(self, c1, c2, k=(5, 9, 13)):
226
+ super(SPP, self).__init__()
227
+ c_ = c1 // 2 # hidden channels
228
+ self.cv1 = Conv(c1, c_, 1, 1)
229
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
230
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
231
+
232
+ def forward(self, x):
233
+ x = self.cv1(x)
234
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
235
+
236
+
237
+ class Bottleneck(nn.Module):
238
+ # Darknet bottleneck
239
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
240
+ super(Bottleneck, self).__init__()
241
+ c_ = int(c2 * e) # hidden channels
242
+ self.cv1 = Conv(c1, c_, 1, 1)
243
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
244
+ self.add = shortcut and c1 == c2
245
+
246
+ def forward(self, x):
247
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
248
+
249
+
250
+ class Res(nn.Module):
251
+ # ResNet bottleneck
252
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
253
+ super(Res, self).__init__()
254
+ c_ = int(c2 * e) # hidden channels
255
+ self.cv1 = Conv(c1, c_, 1, 1)
256
+ self.cv2 = Conv(c_, c_, 3, 1, g=g)
257
+ self.cv3 = Conv(c_, c2, 1, 1)
258
+ self.add = shortcut and c1 == c2
259
+
260
+ def forward(self, x):
261
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
262
+
263
+
264
+ class ResX(Res):
265
+ # ResNeXt bottleneck (grouped 3x3 convolution, g=32 by default)
266
+ def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
267
+ super().__init__(c1, c2, shortcut, g, e)
268
+ c_ = int(c2 * e) # hidden channels
269
+
270
+
271
+ class Ghost(nn.Module):
272
+ # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
273
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
274
+ super(Ghost, self).__init__()
275
+ c_ = c2 // 2
276
+ self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
277
+ DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
278
+ GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
279
+ self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
280
+ Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
281
+
282
+ def forward(self, x):
283
+ return self.conv(x) + self.shortcut(x)
284
+
285
+ ##### end of basic #####
286
+
287
+
288
+ ##### cspnet #####
289
+
290
+ class SPPCSPC(nn.Module):
291
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
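+ # SPP wrapped in a Cross Stage Partial structure: one branch runs convs + multi-scale max-pooling, the other is a single 1x1 conv; outputs are concatenated and fused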
292
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
293
+ super(SPPCSPC, self).__init__()
294
+ c_ = int(2 * c2 * e) # hidden channels
295
+ self.cv1 = Conv(c1, c_, 1, 1)
296
+ self.cv2 = Conv(c1, c_, 1, 1)
297
+ self.cv3 = Conv(c_, c_, 3, 1)
298
+ self.cv4 = Conv(c_, c_, 1, 1)
299
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
300
+ self.cv5 = Conv(4 * c_, c_, 1, 1)
301
+ self.cv6 = Conv(c_, c_, 3, 1)
302
+ self.cv7 = Conv(2 * c_, c2, 1, 1)
303
+
304
+ def forward(self, x):
305
+ x1 = self.cv4(self.cv3(self.cv1(x)))
306
+ y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
307
+ y2 = self.cv2(x)
308
+ return self.cv7(torch.cat((y1, y2), dim=1))
309
+
310
+ class GhostSPPCSPC(SPPCSPC):
311
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
312
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
313
+ super().__init__(c1, c2, n, shortcut, g, e, k)
314
+ c_ = int(2 * c2 * e) # hidden channels
315
+ self.cv1 = GhostConv(c1, c_, 1, 1)
316
+ self.cv2 = GhostConv(c1, c_, 1, 1)
317
+ self.cv3 = GhostConv(c_, c_, 3, 1)
318
+ self.cv4 = GhostConv(c_, c_, 1, 1)
319
+ self.cv5 = GhostConv(4 * c_, c_, 1, 1)
320
+ self.cv6 = GhostConv(c_, c_, 3, 1)
321
+ self.cv7 = GhostConv(2 * c_, c2, 1, 1)
322
+
323
+
324
+ class GhostStem(Stem):
325
+ # Stem
326
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
327
+ super().__init__(c1, c2, k, s, p, g, act)
328
+ c_ = int(c2/2) # hidden channels
329
+ self.cv1 = GhostConv(c1, c_, 3, 2)
330
+ self.cv2 = GhostConv(c_, c_, 1, 1)
331
+ self.cv3 = GhostConv(c_, c_, 3, 2)
332
+ self.cv4 = GhostConv(2 * c_, c2, 1, 1)
333
+
334
+
335
+ class BottleneckCSPA(nn.Module):
336
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
337
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
338
+ super(BottleneckCSPA, self).__init__()
339
+ c_ = int(c2 * e) # hidden channels
340
+ self.cv1 = Conv(c1, c_, 1, 1)
341
+ self.cv2 = Conv(c1, c_, 1, 1)
342
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
343
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
344
+
345
+ def forward(self, x):
346
+ y1 = self.m(self.cv1(x))
347
+ y2 = self.cv2(x)
348
+ return self.cv3(torch.cat((y1, y2), dim=1))
349
+
350
+
351
+ class BottleneckCSPB(nn.Module):
352
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
353
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
354
+ super(BottleneckCSPB, self).__init__()
355
+ c_ = int(c2) # hidden channels
356
+ self.cv1 = Conv(c1, c_, 1, 1)
357
+ self.cv2 = Conv(c_, c_, 1, 1)
358
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
359
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
360
+
361
+ def forward(self, x):
362
+ x1 = self.cv1(x)
363
+ y1 = self.m(x1)
364
+ y2 = self.cv2(x1)
365
+ return self.cv3(torch.cat((y1, y2), dim=1))
366
+
367
+
368
+ class BottleneckCSPC(nn.Module):
369
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
370
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
371
+ super(BottleneckCSPC, self).__init__()
372
+ c_ = int(c2 * e) # hidden channels
373
+ self.cv1 = Conv(c1, c_, 1, 1)
374
+ self.cv2 = Conv(c1, c_, 1, 1)
375
+ self.cv3 = Conv(c_, c_, 1, 1)
376
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
377
+ self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
378
+
379
+ def forward(self, x):
380
+ y1 = self.cv3(self.m(self.cv1(x)))
381
+ y2 = self.cv2(x)
382
+ return self.cv4(torch.cat((y1, y2), dim=1))
383
+
384
+
385
+ class ResCSPA(BottleneckCSPA):
386
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
387
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
388
+ super().__init__(c1, c2, n, shortcut, g, e)
389
+ c_ = int(c2 * e) # hidden channels
390
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
391
+
392
+
393
+ class ResCSPB(BottleneckCSPB):
394
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
395
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
396
+ super().__init__(c1, c2, n, shortcut, g, e)
397
+ c_ = int(c2) # hidden channels
398
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
399
+
400
+
401
+ class ResCSPC(BottleneckCSPC):
402
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
403
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
404
+ super().__init__(c1, c2, n, shortcut, g, e)
405
+ c_ = int(c2 * e) # hidden channels
406
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
407
+
408
+
409
+ class ResXCSPA(ResCSPA):
410
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
411
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
412
+ super().__init__(c1, c2, n, shortcut, g, e)
413
+ c_ = int(c2 * e) # hidden channels
414
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
415
+
416
+
417
+ class ResXCSPB(ResCSPB):
418
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
419
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
420
+ super().__init__(c1, c2, n, shortcut, g, e)
421
+ c_ = int(c2) # hidden channels
422
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
423
+
424
+
425
+ class ResXCSPC(ResCSPC):
426
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
427
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
428
+ super().__init__(c1, c2, n, shortcut, g, e)
429
+ c_ = int(c2 * e) # hidden channels
430
+ self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
431
+
432
+
433
+ class GhostCSPA(BottleneckCSPA):
434
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
435
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
436
+ super().__init__(c1, c2, n, shortcut, g, e)
437
+ c_ = int(c2 * e) # hidden channels
438
+ self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])
439
+
440
+
441
+ class GhostCSPB(BottleneckCSPB):
442
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
443
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
444
+ super().__init__(c1, c2, n, shortcut, g, e)
445
+ c_ = int(c2) # hidden channels
446
+ self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])
447
+
448
+
449
+ class GhostCSPC(BottleneckCSPC):
450
+ # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
451
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
452
+ super().__init__(c1, c2, n, shortcut, g, e)
453
+ c_ = int(c2 * e) # hidden channels
454
+ self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])
455
+
456
+ ##### end of cspnet #####
457
+
458
+
459
+ ##### yolor #####
460
+
461
+ class ImplicitA(nn.Module):
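+ # implicit knowledge (addition) from YOLOR: a learned per-channel bias added to the feature map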
462
+ def __init__(self, channel, mean=0., std=.02):
463
+ super(ImplicitA, self).__init__()
464
+ self.channel = channel
465
+ self.mean = mean
466
+ self.std = std
467
+ self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
468
+ nn.init.normal_(self.implicit, mean=self.mean, std=self.std)
469
+
470
+ def forward(self, x):
471
+ return self.implicit + x
472
+
473
+
474
+ class ImplicitM(nn.Module):
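+ # implicit knowledge (multiplication) from YOLOR: a learned per-channel scale applied to the feature map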
475
+ def __init__(self, channel, mean=0., std=.02):
476
+ super(ImplicitM, self).__init__()
477
+ self.channel = channel
478
+ self.mean = mean
479
+ self.std = std
480
+ self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
481
+ nn.init.normal_(self.implicit, mean=self.mean, std=self.std)
482
+
483
+ def forward(self, x):
484
+ return self.implicit * x
485
+
486
+ ##### end of yolor #####
487
+
488
+
489
+ ##### repvgg #####
490
+
491
+ class RepConv(nn.Module):
492
+ # Re-parameterized convolution (RepVGG)
493
+ # https://arxiv.org/abs/2101.03697
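+ # training uses three parallel branches (3x3 conv+BN, 1x1 conv+BN, optional identity BN); at deploy time they are fused into a single 3x3 conv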
494
+
495
+ def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, deploy=False):
496
+ super(RepConv, self).__init__()
497
+
498
+ self.deploy = deploy
499
+ self.groups = g
500
+ self.in_channels = c1
501
+ self.out_channels = c2
502
+
503
+ assert k == 3
504
+ assert autopad(k, p) == 1
505
+
506
+ padding_11 = autopad(k, p) - k // 2
507
+
508
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
509
+
510
+ if deploy:
511
+ self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)
512
+
513
+ else:
514
+ self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
515
+
516
+ self.rbr_dense = nn.Sequential(
517
+ nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
518
+ nn.BatchNorm2d(num_features=c2),
519
+ )
520
+
521
+ self.rbr_1x1 = nn.Sequential(
522
+ nn.Conv2d( c1, c2, 1, s, padding_11, groups=g, bias=False),
523
+ nn.BatchNorm2d(num_features=c2),
524
+ )
525
+
526
+ def forward(self, inputs):
527
+ if hasattr(self, "rbr_reparam"):
528
+ return self.act(self.rbr_reparam(inputs))
529
+
530
+ if self.rbr_identity is None:
531
+ id_out = 0
532
+ else:
533
+ id_out = self.rbr_identity(inputs)
534
+
535
+ return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
536
+
537
+ def get_equivalent_kernel_bias(self):
538
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
539
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
540
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
541
+ return (
542
+ kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
543
+ bias3x3 + bias1x1 + biasid,
544
+ )
545
+
546
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
547
+ if kernel1x1 is None:
548
+ return 0
549
+ else:
550
+ return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
551
+
552
+ def _fuse_bn_tensor(self, branch):
553
+ if branch is None:
554
+ return 0, 0
555
+ if isinstance(branch, nn.Sequential):
556
+ kernel = branch[0].weight
557
+ running_mean = branch[1].running_mean
558
+ running_var = branch[1].running_var
559
+ gamma = branch[1].weight
560
+ beta = branch[1].bias
561
+ eps = branch[1].eps
562
+ else:
563
+ assert isinstance(branch, nn.BatchNorm2d)
564
+ if not hasattr(self, "id_tensor"):
565
+ input_dim = self.in_channels // self.groups
566
+ kernel_value = np.zeros(
567
+ (self.in_channels, input_dim, 3, 3), dtype=np.float32
568
+ )
569
+ for i in range(self.in_channels):
570
+ kernel_value[i, i % input_dim, 1, 1] = 1
571
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
572
+ kernel = self.id_tensor
573
+ running_mean = branch.running_mean
574
+ running_var = branch.running_var
575
+ gamma = branch.weight
576
+ beta = branch.bias
577
+ eps = branch.eps
578
+ std = (running_var + eps).sqrt()
579
+ t = (gamma / std).reshape(-1, 1, 1, 1)
580
+ return kernel * t, beta - running_mean * gamma / std
581
+
582
+ def repvgg_convert(self):
583
+ kernel, bias = self.get_equivalent_kernel_bias()
584
+ return (
585
+ kernel.detach().cpu().numpy(),
586
+ bias.detach().cpu().numpy(),
587
+ )
588
+
589
+ def fuse_conv_bn(self, conv, bn):
590
+
591
+ std = (bn.running_var + bn.eps).sqrt()
592
+ bias = bn.bias - bn.running_mean * bn.weight / std
593
+
594
+ t = (bn.weight / std).reshape(-1, 1, 1, 1)
595
+ weights = conv.weight * t
596
+
597
+ bn = nn.Identity()
598
+ conv = nn.Conv2d(in_channels = conv.in_channels,
599
+ out_channels = conv.out_channels,
600
+ kernel_size = conv.kernel_size,
601
+ stride=conv.stride,
602
+ padding = conv.padding,
603
+ dilation = conv.dilation,
604
+ groups = conv.groups,
605
+ bias = True,
606
+ padding_mode = conv.padding_mode)
607
+
608
+ conv.weight = torch.nn.Parameter(weights)
609
+ conv.bias = torch.nn.Parameter(bias)
610
+ return conv
611
+
612
+ def fuse_repvgg_block(self):
613
+ if self.deploy:
614
+ return
615
+ print(f"RepConv.fuse_repvgg_block")
616
+
617
+ self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
618
+
619
+ self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
620
+ rbr_1x1_bias = self.rbr_1x1.bias
621
+ weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
622
+
623
+ # Fuse self.rbr_identity
624
+ if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
625
+ # print(f"fuse: rbr_identity == BatchNorm2d or SyncBatchNorm")
626
+ identity_conv_1x1 = nn.Conv2d(
627
+ in_channels=self.in_channels,
628
+ out_channels=self.out_channels,
629
+ kernel_size=1,
630
+ stride=1,
631
+ padding=0,
632
+ groups=self.groups,
633
+ bias=False)
634
+ identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
635
+ identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
636
+ # print(f" identity_conv_1x1.weight = {identity_conv_1x1.weight.shape}")
637
+ identity_conv_1x1.weight.data.fill_(0.0)
638
+ identity_conv_1x1.weight.data.fill_diagonal_(1.0)
639
+ identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
640
+ # print(f" identity_conv_1x1.weight = {identity_conv_1x1.weight.shape}")
641
+
642
+ identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
643
+ bias_identity_expanded = identity_conv_1x1.bias
644
+ weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
645
+ else:
646
+ # print(f"fuse: rbr_identity != BatchNorm2d, rbr_identity = {self.rbr_identity}")
647
+ bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
648
+ weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )
649
+
650
+
651
+ #print(f"self.rbr_1x1.weight = {self.rbr_1x1.weight.shape}, ")
652
+ #print(f"weight_1x1_expanded = {weight_1x1_expanded.shape}, ")
653
+ #print(f"self.rbr_dense.weight = {self.rbr_dense.weight.shape}, ")
654
+
655
+ self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
656
+ self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
657
+
658
+ self.rbr_reparam = self.rbr_dense
659
+ self.deploy = True
660
+
661
+ if self.rbr_identity is not None:
662
+ del self.rbr_identity
663
+ self.rbr_identity = None
664
+
665
+ if self.rbr_1x1 is not None:
666
+ del self.rbr_1x1
667
+ self.rbr_1x1 = None
668
+
669
+ if self.rbr_dense is not None:
670
+ del self.rbr_dense
671
+ self.rbr_dense = None
672
+
673
+
674
+ class RepBottleneck(Bottleneck):
675
+ # Standard bottleneck
676
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
677
+ super().__init__(c1, c2, shortcut, g=g, e=e)
678
+ c_ = int(c2 * e) # hidden channels
679
+ self.cv2 = RepConv(c_, c2, 3, 1, g=g)
680
+
681
+
682
+ class RepBottleneckCSPA(BottleneckCSPA):
683
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
684
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
685
+ super().__init__(c1, c2, n, shortcut, g, e)
686
+ c_ = int(c2 * e) # hidden channels
687
+ self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
688
+
689
+
690
+ class RepBottleneckCSPB(BottleneckCSPB):
691
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
692
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
693
+ super().__init__(c1, c2, n, shortcut, g, e)
694
+ c_ = int(c2) # hidden channels
695
+ self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
696
+
697
+
698
+ class RepBottleneckCSPC(BottleneckCSPC):
699
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
700
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
701
+ super().__init__(c1, c2, n, shortcut, g, e)
702
+ c_ = int(c2 * e) # hidden channels
703
+ self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
704
+
705
+
706
+ class RepRes(Res):
707
+ # Standard bottleneck
708
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
709
+ super().__init__(c1, c2, shortcut, g, e)
710
+ c_ = int(c2 * e) # hidden channels
711
+ self.cv2 = RepConv(c_, c_, 3, 1, g=g)
712
+
713
+
714
+ class RepResCSPA(ResCSPA):
715
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
716
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
717
+ super().__init__(c1, c2, n, shortcut, g, e)
718
+ c_ = int(c2 * e) # hidden channels
719
+ self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
720
+
721
+
722
+ class RepResCSPB(ResCSPB):
723
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
724
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
725
+ super().__init__(c1, c2, n, shortcut, g, e)
726
+ c_ = int(c2) # hidden channels
727
+ self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
728
+
729
+
730
+ class RepResCSPC(ResCSPC):
731
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
732
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
733
+ super().__init__(c1, c2, n, shortcut, g, e)
734
+ c_ = int(c2 * e) # hidden channels
735
+ self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
736
+
737
+
738
+ class RepResX(ResX):
739
+ # Standard bottleneck
740
+ def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
741
+ super().__init__(c1, c2, shortcut, g, e)
742
+ c_ = int(c2 * e) # hidden channels
743
+ self.cv2 = RepConv(c_, c_, 3, 1, g=g)
744
+
745
+
746
+ class RepResXCSPA(ResXCSPA):
747
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
748
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
749
+ super().__init__(c1, c2, n, shortcut, g, e)
750
+ c_ = int(c2 * e) # hidden channels
751
+ self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
752
+
753
+
754
+ class RepResXCSPB(ResXCSPB):
755
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
756
+ def __init__(self, c1, c2, n=1, shortcut=False, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
757
+ super().__init__(c1, c2, n, shortcut, g, e)
758
+ c_ = int(c2) # hidden channels
759
+ self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
760
+
761
+
762
+ class RepResXCSPC(ResXCSPC):
763
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
764
+ def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
765
+ super().__init__(c1, c2, n, shortcut, g, e)
766
+ c_ = int(c2 * e) # hidden channels
767
+ self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])
768
+
769
+ ##### end of repvgg #####
770
+
771
+
772
+ ##### transformer #####
773
+
774
+ class TransformerLayer(nn.Module):
775
+ # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
776
+ def __init__(self, c, num_heads):
777
+ super().__init__()
778
+ self.q = nn.Linear(c, c, bias=False)
779
+ self.k = nn.Linear(c, c, bias=False)
780
+ self.v = nn.Linear(c, c, bias=False)
781
+ self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
782
+ self.fc1 = nn.Linear(c, c, bias=False)
783
+ self.fc2 = nn.Linear(c, c, bias=False)
784
+
785
+ def forward(self, x):
786
+ x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
787
+ x = self.fc2(self.fc1(x)) + x
788
+ return x
789
+
790
+
791
+ class TransformerBlock(nn.Module):
792
+ # Vision Transformer https://arxiv.org/abs/2010.11929
793
+ def __init__(self, c1, c2, num_heads, num_layers):
794
+ super().__init__()
795
+ self.conv = None
796
+ if c1 != c2:
797
+ self.conv = Conv(c1, c2)
798
+ self.linear = nn.Linear(c2, c2) # learnable position embedding
799
+ self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
800
+ self.c2 = c2
801
+
802
+ def forward(self, x):
803
+ if self.conv is not None:
804
+ x = self.conv(x)
805
+ b, _, w, h = x.shape # note: for an NCHW tensor the last two dims are (h, w); the names are swapped here but used consistently below
806
+ p = x.flatten(2)
807
+ p = p.unsqueeze(0)
808
+ p = p.transpose(0, 3)
809
+ p = p.squeeze(3)
810
+ e = self.linear(p)
811
+ x = p + e
812
+
813
+ x = self.tr(x)
814
+ x = x.unsqueeze(3)
815
+ x = x.transpose(0, 3)
816
+ x = x.reshape(b, self.c2, w, h)
817
+ return x
818
+
819
+ ##### end of transformer #####
820
+
821
+
822
+ ##### yolov5 #####
823
+
824
+ class Focus(nn.Module):
825
+ # Focus wh information into c-space
826
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
827
+ super(Focus, self).__init__()
828
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
829
+ # self.contract = Contract(gain=2)
830
+
831
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
832
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
833
+ # return self.conv(self.contract(x))
834
+
835
+
836
+ class SPPF(nn.Module):
837
+ # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
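+ # three cascaded k=5 max-pools give effective kernel sizes 5, 9 and 13, matching SPP(k=(5, 9, 13)) at lower cost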
838
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
839
+ super().__init__()
840
+ c_ = c1 // 2 # hidden channels
841
+ self.cv1 = Conv(c1, c_, 1, 1)
842
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
843
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
844
+
845
+ def forward(self, x):
846
+ x = self.cv1(x)
847
+ y1 = self.m(x)
848
+ y2 = self.m(y1)
849
+ return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
850
+
851
+
852
+ class Contract(nn.Module):
853
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
854
+ def __init__(self, gain=2):
855
+ super().__init__()
856
+ self.gain = gain
857
+
858
+ def forward(self, x):
859
+ N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
860
+ s = self.gain
861
+ x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
862
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
863
+ return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)
864
+
865
+
866
+ class Expand(nn.Module):
867
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
868
+ def __init__(self, gain=2):
869
+ super().__init__()
870
+ self.gain = gain
871
+
872
+ def forward(self, x):
873
+ N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
874
+ s = self.gain
875
+ x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
876
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
877
+ return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)
878
+
879
+
880
+ class NMS(nn.Module):
881
+ # Non-Maximum Suppression (NMS) module
882
+ conf = 0.25 # confidence threshold
883
+ iou = 0.45 # IoU threshold
884
+ classes = None # (optional list) filter by class
885
+
886
+ def __init__(self):
887
+ super(NMS, self).__init__()
888
+
889
+ def forward(self, x):
890
+ return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
891
+
892
+
893
+ class autoShape(nn.Module):
894
+ # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
895
+ conf = 0.25 # NMS confidence threshold
896
+ iou = 0.45 # NMS IoU threshold
897
+ classes = None # (optional list) filter by class
898
+
899
+ def __init__(self, model):
900
+ super(autoShape, self).__init__()
901
+ self.model = model.eval()
902
+
903
+ def autoshape(self):
904
+ print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
905
+ return self
906
+
907
+ @torch.no_grad()
908
+ def forward(self, imgs, size=640, augment=False, profile=False):
909
+ # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
910
+ # filename: imgs = 'data/samples/zidane.jpg'
911
+ # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
912
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
913
+ # PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
914
+ # numpy: = np.zeros((640,1280,3)) # HWC
915
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
916
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
917
+
918
+ t = [time_synchronized()]
919
+ p = next(self.model.parameters()) # for device and type
920
+ if isinstance(imgs, torch.Tensor): # torch
921
+ with amp.autocast(enabled=p.device.type != 'cpu'):
922
+ return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
923
+
924
+ # Pre-process
925
+ n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
926
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
927
+ for i, im in enumerate(imgs):
928
+ f = f'image{i}' # filename
929
+ if isinstance(im, str): # filename or uri
930
+ im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
931
+ elif isinstance(im, Image.Image): # PIL Image
932
+ im, f = np.asarray(im), getattr(im, 'filename', f) or f
933
+ files.append(Path(f).with_suffix('.jpg').name)
934
+ if im.shape[0] < 5: # image in CHW
935
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
936
+ im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
937
+ s = im.shape[:2] # HWC
938
+ shape0.append(s) # image shape
939
+ g = (size / max(s)) # gain
940
+ shape1.append([y * g for y in s])
941
+ imgs[i] = im # update
942
+ shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
943
+ x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
944
+ x = np.stack(x, 0) if n > 1 else x[0][None] # stack
945
+ x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
946
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
947
+ t.append(time_synchronized())
948
+
949
+ with amp.autocast(enabled=p.device.type != 'cpu'):
950
+ # Inference
951
+ y = self.model(x, augment, profile)[0] # forward
952
+ t.append(time_synchronized())
953
+
954
+ # Post-process
955
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
956
+ for i in range(n):
957
+ scale_coords(shape1, y[i][:, :4], shape0[i])
958
+
959
+ t.append(time_synchronized())
960
+ return Detections(imgs, y, files, t, self.names, x.shape)
961
+
962
+
963
+ class Detections:
964
+ # detections class for YOLOv5 inference results
965
+ def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
966
+ super(Detections, self).__init__()
967
+ d = pred[0].device # device
968
+ gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
969
+ self.imgs = imgs # list of images as numpy arrays
970
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
971
+ self.names = names # class names
972
+ self.files = files # image filenames
973
+ self.xyxy = pred # xyxy pixels
974
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
975
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
976
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
977
+ self.n = len(self.pred) # number of images (batch size)
978
+ self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) if times else (0., 0., 0.) # timestamps (ms)
979
+ self.s = shape # inference BCHW shape
980
+
981
+ def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
982
+ colors = color_list()
983
+ for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
984
+ s = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} ' # avoid shadowing built-in str
985
+ if pred is not None:
986
+ for c in pred[:, -1].unique():
987
+ n = (pred[:, -1] == c).sum() # detections per class
988
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
989
+ if show or save or render:
990
+ for *box, conf, cls in pred: # xyxy, confidence, class
991
+ label = f'{self.names[int(cls)]} {conf:.2f}'
992
+ plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
993
+ img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
994
+ if pprint:
995
+ print(s.rstrip(', '))
996
+ if show:
997
+ img.show(self.files[i]) # show
998
+ if save:
999
+ f = self.files[i]
1000
+ img.save(Path(save_dir) / f) # save
1001
+ print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
1002
+ if render:
1003
+ self.imgs[i] = np.asarray(img)
1004
+
1005
+ def print(self):
1006
+ self.display(pprint=True) # print results
1007
+ print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)
1008
+
1009
+ def show(self):
1010
+ self.display(show=True) # show results
1011
+
1012
+ def save(self, save_dir='runs/hub/exp'):
1013
+ save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp') # increment save_dir
1014
+ Path(save_dir).mkdir(parents=True, exist_ok=True)
1015
+ self.display(save=True, save_dir=save_dir) # save results
1016
+
1017
+ def render(self):
1018
+ self.display(render=True) # render results
1019
+ return self.imgs
1020
+
1021
+ def pandas(self):
1022
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
1023
+ new = copy(self) # return copy
1024
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
1025
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
1026
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
1027
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
1028
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
1029
+ return new
1030
+
1031
+ def tolist(self):
1032
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
1033
+ x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], names=self.names, shape=self.s) for i in range(self.n)]
1034
+ for d in x:
1035
+ for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
1036
+ setattr(d, k, getattr(d, k)[0]) # pop out of list
1037
+ return x
1038
+
1039
+ def __len__(self):
1040
+ return self.n
1041
+
1042
+
1043
+ class Classify(nn.Module):
1044
+ # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
1045
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
1046
+ super(Classify, self).__init__()
1047
+ self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
1048
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
1049
+ self.flat = nn.Flatten()
1050
+
1051
+ def forward(self, x):
1052
+ z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
1053
+ return self.flat(self.conv(z)) # flatten to x(b,c2)
1054
+
1055
+ ##### end of yolov5 ######
1056
+
1057
+
1058
+ ##### orepa #####
1059
+
1060
+ def transI_fusebn(kernel, bn):
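+ # fold BatchNorm into the preceding conv: w' = w * gamma / std, b' = beta - running_mean * gamma / std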
1061
+ gamma = bn.weight
1062
+ std = (bn.running_var + bn.eps).sqrt()
1063
+ return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std
1064
+
1065
+
1066
+ class ConvBN(nn.Module):
1067
+ def __init__(self, in_channels, out_channels, kernel_size,
1068
+ stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None):
1069
+ super().__init__()
1070
+ if nonlinear is None:
1071
+ self.nonlinear = nn.Identity()
1072
+ else:
1073
+ self.nonlinear = nonlinear
1074
+ if deploy:
1075
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
1076
+ stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
1077
+ else:
1078
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
1079
+ stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
1080
+ self.bn = nn.BatchNorm2d(num_features=out_channels)
1081
+
1082
+ def forward(self, x):
1083
+ if hasattr(self, 'bn'):
1084
+ return self.nonlinear(self.bn(self.conv(x)))
1085
+ else:
1086
+ return self.nonlinear(self.conv(x))
1087
+
1088
+ def switch_to_deploy(self):
1089
+ kernel, bias = transI_fusebn(self.conv.weight, self.bn)
1090
+ conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels, kernel_size=self.conv.kernel_size,
1091
+ stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True)
1092
+ conv.weight.data = kernel
1093
+ conv.bias.data = bias
1094
+ for para in self.parameters():
1095
+ para.detach_()
1096
+ self.__delattr__('conv')
1097
+ self.__delattr__('bn')
1098
+ self.conv = conv
1099
+
1100
+ class OREPA_3x3_RepConv(nn.Module):
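+ # OREPA (Online Convolutional Re-parameterization): several weight branches are combined by weight_gen() into a single 3x3 kernel on every forward pass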
1101
+
1102
+ def __init__(self, in_channels, out_channels, kernel_size,
1103
+ stride=1, padding=0, dilation=1, groups=1,
1104
+ internal_channels_1x1_3x3=None,
1105
+ deploy=False, nonlinear=None, single_init=False):
1106
+ super(OREPA_3x3_RepConv, self).__init__()
1107
+ self.deploy = deploy
1108
+
1109
+ if nonlinear is None:
1110
+ self.nonlinear = nn.Identity()
1111
+ else:
1112
+ self.nonlinear = nonlinear
1113
+
1114
+ self.kernel_size = kernel_size
1115
+ self.in_channels = in_channels
1116
+ self.out_channels = out_channels
1117
+ self.groups = groups
1118
+ assert padding == kernel_size // 2
1119
+
1120
+ self.stride = stride
1121
+ self.padding = padding
1122
+ self.dilation = dilation
1123
+
1124
+ self.branch_counter = 0
1125
+
1126
+ self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), kernel_size, kernel_size))
1127
+ nn.init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0))
1128
+ self.branch_counter += 1
1129
+
1130
+
1131
+ if groups < out_channels:
1132
+ self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
1133
+ self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels/self.groups), 1, 1))
1134
+ nn.init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0)
1135
+ nn.init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0)
1136
+ self.weight_rbr_avg_conv.data
1137
+ self.weight_rbr_pfir_conv.data
1138
+ self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0/kernel_size/kernel_size))
1139
+ self.branch_counter += 1
1140
+
1141
+ else:
1142
+ raise NotImplementedError
1143
+ self.branch_counter += 1
1144
+
1145
+ if internal_channels_1x1_3x3 is None:
1146
+ internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
1147
+
1148
+ if internal_channels_1x1_3x3 == in_channels:
1149
+ self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels/self.groups), 1, 1))
1150
+ id_value = np.zeros((in_channels, int(in_channels/self.groups), 1, 1))
1151
+ for i in range(in_channels):
1152
+ id_value[i, i % int(in_channels/self.groups), 0, 0] = 1
1153
+ id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1)
1154
+ self.register_buffer('id_tensor', id_tensor)
1155
+
1156
+ else:
1157
+ self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels/self.groups), 1, 1))
1158
+ nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0))
1159
+ self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3/self.groups), kernel_size, kernel_size))
1160
+ nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0))
1161
+ self.branch_counter += 1
1162
+
1163
+ expand_ratio = 8
1164
+ self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels*expand_ratio, 1, kernel_size, kernel_size))
1165
+ self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels*expand_ratio, 1, 1))
1166
+ nn.init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0))
1167
+ nn.init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0))
1168
+ self.branch_counter += 1
1169
+
1170
+ if out_channels == in_channels and stride == 1:
1171
+ self.branch_counter += 1
1172
+
1173
+ self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
1174
+ self.bn = nn.BatchNorm2d(out_channels)
1175
+
1176
+ self.fre_init()
1177
+
1178
+ nn.init.constant_(self.vector[0, :], 0.25) #origin
1179
+ nn.init.constant_(self.vector[1, :], 0.25) #avg
1180
+ nn.init.constant_(self.vector[2, :], 0.0) #prior
1181
+ nn.init.constant_(self.vector[3, :], 0.5) #1x1_kxk
1182
+ nn.init.constant_(self.vector[4, :], 0.5) #dws_conv
1183
+
1184
+
1185
+ def fre_init(self):
1186
+ prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size)
1187
+ half_fg = self.out_channels/2
1188
+ for i in range(self.out_channels):
1189
+ for h in range(3):
1190
+ for w in range(3):
1191
+ if i < half_fg:
1192
+ prior_tensor[i, h, w] = math.cos(math.pi*(h+0.5)*(i+1)/3)
1193
+ else:
1194
+ prior_tensor[i, h, w] = math.cos(math.pi*(w+0.5)*(i+1-half_fg)/3)
1195
+
1196
+ self.register_buffer('weight_rbr_prior', prior_tensor)
1197
+
1198
+ def weight_gen(self):
1199
+
1200
+ weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :])
1201
+
1202
+ weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :])
1203
+
1204
+ weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :])
1205
+
1206
+ weight_rbr_1x1_kxk_conv1 = None
1207
+ if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'):
1208
+ weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze()
1209
+ elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'):
1210
+ weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze()
1211
+ else:
1212
+ raise NotImplementedError
1213
+ weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2
1214
+
1215
+ if self.groups > 1:
1216
+ g = self.groups
1217
+ t, ig = weight_rbr_1x1_kxk_conv1.size()
1218
+ o, tg, h, w = weight_rbr_1x1_kxk_conv2.size()
1219
+ weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t/g), ig)
1220
+ weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o/g), tg, h, w)
1221
+ weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w)
1222
+ else:
1223
+ weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2)
1224
+
1225
+ weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :])
1226
+
1227
+ weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels)
1228
+ weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :])
1229
+
1230
+ weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv
1231
+
1232
+ return weight
1233
+
1234
+ def dwsc2full(self, weight_dw, weight_pw, groups):
1235
+
1236
+ t, ig, h, w = weight_dw.size()
1237
+ o, _, _, _ = weight_pw.size()
1238
+ tg = int(t/groups)
1239
+ i = int(ig*groups)
1240
+ weight_dw = weight_dw.view(groups, tg, ig, h, w)
1241
+ weight_pw = weight_pw.squeeze().view(o, groups, tg)
1242
+
1243
+ weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw)
1244
+ return weight_dsc.view(o, i, h, w)
1245
+
1246
+ def forward(self, inputs):
1247
+ weight = self.weight_gen()
1248
+ out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
1249
+
1250
+ return self.nonlinear(self.bn(out))
1251
+
1252
+ class RepConv_OREPA(nn.Module):
1253
+
1254
+ def __init__(self, c1, c2, k=3, s=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.SiLU()):
1255
+ super(RepConv_OREPA, self).__init__()
1256
+ self.deploy = deploy
1257
+ self.groups = groups
1258
+ self.in_channels = c1
1259
+ self.out_channels = c2
1260
+
1261
+ self.padding = padding
1262
+ self.dilation = dilation
1263
+ self.groups = groups
1264
+
1265
+ assert k == 3
1266
+ assert padding == 1
1267
+
1268
+ padding_11 = padding - k // 2
1269
+
1270
+ if nonlinear is None:
1271
+ self.nonlinearity = nn.Identity()
1272
+ else:
1273
+ self.nonlinearity = nonlinear
1274
+
1275
+ if use_se:
1276
+ self.se = SEBlock(self.out_channels, internal_neurons=self.out_channels // 16)
1277
+ else:
1278
+ self.se = nn.Identity()
1279
+
1280
+ if deploy:
1281
+ self.rbr_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s,
1282
+ padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
1283
+
1284
+ else:
1285
+ self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.out_channels == self.in_channels and s == 1 else None
1286
+ self.rbr_dense = OREPA_3x3_RepConv(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s, padding=padding, groups=groups, dilation=1)
1287
+ self.rbr_1x1 = ConvBN(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=s, padding=padding_11, groups=groups, dilation=1)
1288
+ print('RepVGG Block, identity = ', self.rbr_identity)
1289
+
1290
+
1291
+ def forward(self, inputs):
1292
+ if hasattr(self, 'rbr_reparam'):
1293
+ return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
1294
+
1295
+ if self.rbr_identity is None:
1296
+ id_out = 0
1297
+ else:
1298
+ id_out = self.rbr_identity(inputs)
1299
+
1300
+ out1 = self.rbr_dense(inputs)
1301
+ out2 = self.rbr_1x1(inputs)
1302
+ out3 = id_out
1303
+ out = out1 + out2 + out3
1304
+
1305
+ return self.nonlinearity(self.se(out))
1306
+
1307
+
1308
+ # Optional. This improves the accuracy and facilitates quantization.
1309
+ # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
1310
+ # 2. Use like this.
1311
+ # loss = criterion(....)
1312
+ # for every RepVGGBlock blk:
1313
+ # loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
1314
+ # optimizer.zero_grad()
1315
+ # loss.backward()
1316
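+ # A concrete (illustrative) version of the loop sketched above, assuming `model`, `criterion`,
+ # `optimizer`, `imgs`, `targets` and `weight_decay_coefficient` already exist in the training script:
+ # loss = criterion(model(imgs), targets)
+ # for blk in model.modules():
+ #     if isinstance(blk, RepConv_OREPA) and not hasattr(blk, 'rbr_reparam'):
+ #         loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
+ # optimizer.zero_grad()
+ # loss.backward()
+ # optimizer.step()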
+
1317
+ # Not used for OREPA
1318
+ def get_custom_L2(self):
1319
+ K3 = self.rbr_dense.weight_gen()
1320
+ K1 = self.rbr_1x1.conv.weight
1321
+ t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
1322
+ t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
1323
+
1324
+ l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
1325
+ eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel.
1326
+ l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2.
1327
+ return l2_loss_eq_kernel + l2_loss_circle
1328
+
1329
+ def get_equivalent_kernel_bias(self):
1330
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
1331
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
1332
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
1333
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
1334
+
1335
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
1336
+ if kernel1x1 is None:
1337
+ return 0
1338
+ else:
1339
+ return torch.nn.functional.pad(kernel1x1, [1,1,1,1])
1340
+
1341
+ def _fuse_bn_tensor(self, branch):
1342
+ if branch is None:
1343
+ return 0, 0
1344
+ if not isinstance(branch, nn.BatchNorm2d):
1345
+ if isinstance(branch, OREPA_3x3_RepConv):
1346
+ kernel = branch.weight_gen()
1347
+ elif isinstance(branch, ConvBN):
1348
+ kernel = branch.conv.weight
1349
+ else:
1350
+ raise NotImplementedError
1351
+ running_mean = branch.bn.running_mean
1352
+ running_var = branch.bn.running_var
1353
+ gamma = branch.bn.weight
1354
+ beta = branch.bn.bias
1355
+ eps = branch.bn.eps
1356
+ else:
1357
+ if not hasattr(self, 'id_tensor'):
1358
+ input_dim = self.in_channels // self.groups
1359
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
1360
+ for i in range(self.in_channels):
1361
+ kernel_value[i, i % input_dim, 1, 1] = 1
1362
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
1363
+ kernel = self.id_tensor
1364
+ running_mean = branch.running_mean
1365
+ running_var = branch.running_var
1366
+ gamma = branch.weight
1367
+ beta = branch.bias
1368
+ eps = branch.eps
1369
+ std = (running_var + eps).sqrt()
1370
+ t = (gamma / std).reshape(-1, 1, 1, 1)
1371
+ return kernel * t, beta - running_mean * gamma / std
1372
+
1373
+ def switch_to_deploy(self):
1374
+ if hasattr(self, 'rbr_reparam'):
1375
+ return
1376
+ print(f"RepConv_OREPA.switch_to_deploy")
1377
+ kernel, bias = self.get_equivalent_kernel_bias()
1378
+ self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels,
1379
+ kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride,
1380
+ padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True)
1381
+ self.rbr_reparam.weight.data = kernel
1382
+ self.rbr_reparam.bias.data = bias
1383
+ for para in self.parameters():
1384
+ para.detach_()
1385
+ self.__delattr__('rbr_dense')
1386
+ self.__delattr__('rbr_1x1')
1387
+ if hasattr(self, 'rbr_identity'):
1388
+ self.__delattr__('rbr_identity')
1389
+
1390
+ ##### end of orepa #####
1391
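+ # Illustrative sketch (not part of the original upload; assumes CPU and eval mode): after
+ # switch_to_deploy(), the fused 3x3 convolution should reproduce the multi-branch output.
+ #
+ # m = RepConv_OREPA(64, 64, k=3, s=1).eval()
+ # x = torch.randn(1, 64, 32, 32)
+ # with torch.no_grad():
+ #     y_train = m(x)
+ #     m.switch_to_deploy()
+ #     y_deploy = m(x)
+ # print((y_train - y_deploy).abs().max())  # expected to be ~0 up to numerical tolerance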
+
1392
+
1393
+ ##### swin transformer #####
1394
+
1395
+ class WindowAttention(nn.Module):
1396
+
1397
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
1398
+
1399
+ super().__init__()
1400
+ self.dim = dim
1401
+ self.window_size = window_size # Wh, Ww
1402
+ self.num_heads = num_heads
1403
+ head_dim = dim // num_heads
1404
+ self.scale = qk_scale or head_dim ** -0.5
1405
+
1406
+ # define a parameter table of relative position bias
1407
+ self.relative_position_bias_table = nn.Parameter(
1408
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
1409
+
1410
+ # get pair-wise relative position index for each token inside the window
1411
+ coords_h = torch.arange(self.window_size[0])
1412
+ coords_w = torch.arange(self.window_size[1])
1413
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1414
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1415
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
1416
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
1417
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
1418
+ relative_coords[:, :, 1] += self.window_size[1] - 1
1419
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
1420
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1421
+ self.register_buffer("relative_position_index", relative_position_index)
1422
+
1423
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
1424
+ self.attn_drop = nn.Dropout(attn_drop)
1425
+ self.proj = nn.Linear(dim, dim)
1426
+ self.proj_drop = nn.Dropout(proj_drop)
1427
+
1428
+ nn.init.normal_(self.relative_position_bias_table, std=.02)
1429
+ self.softmax = nn.Softmax(dim=-1)
1430
+
1431
+ def forward(self, x, mask=None):
1432
+
1433
+ B_, N, C = x.shape
1434
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
1435
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
1436
+
1437
+ q = q * self.scale
1438
+ attn = (q @ k.transpose(-2, -1))
1439
+
1440
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
1441
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
1442
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
1443
+ attn = attn + relative_position_bias.unsqueeze(0)
1444
+
1445
+ if mask is not None:
1446
+ nW = mask.shape[0]
1447
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
1448
+ attn = attn.view(-1, self.num_heads, N, N)
1449
+ attn = self.softmax(attn)
1450
+ else:
1451
+ attn = self.softmax(attn)
1452
+
1453
+ attn = self.attn_drop(attn)
1454
+
1455
+ # print(attn.dtype, v.dtype)
1456
+ try:
1457
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
1458
+ except RuntimeError:  # attn/v dtype mismatch under mixed precision; retry with half-precision attn
1459
+ #print(attn.dtype, v.dtype)
1460
+ x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
1461
+ x = self.proj(x)
1462
+ x = self.proj_drop(x)
1463
+ return x
1464
+
1465
+ class Mlp(nn.Module):
1466
+
1467
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
1468
+ super().__init__()
1469
+ out_features = out_features or in_features
1470
+ hidden_features = hidden_features or in_features
1471
+ self.fc1 = nn.Linear(in_features, hidden_features)
1472
+ self.act = act_layer()
1473
+ self.fc2 = nn.Linear(hidden_features, out_features)
1474
+ self.drop = nn.Dropout(drop)
1475
+
1476
+ def forward(self, x):
1477
+ x = self.fc1(x)
1478
+ x = self.act(x)
1479
+ x = self.drop(x)
1480
+ x = self.fc2(x)
1481
+ x = self.drop(x)
1482
+ return x
1483
+
1484
+ def window_partition(x, window_size):
1485
+
1486
+ B, H, W, C = x.shape
1487
+ assert H % window_size == 0 and W % window_size == 0, 'feature map H and W must be divisible by window_size'
1488
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
1489
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
1490
+ return windows
1491
+
1492
+ def window_reverse(windows, window_size, H, W):
1493
+
1494
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
1495
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
1496
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
1497
+ return x
1498
+
1499
+
1500
+ class SwinTransformerLayer(nn.Module):
1501
+
1502
+ def __init__(self, dim, num_heads, window_size=8, shift_size=0,
1503
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
1504
+ act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
1505
+ super().__init__()
1506
+ self.dim = dim
1507
+ self.num_heads = num_heads
1508
+ self.window_size = window_size
1509
+ self.shift_size = shift_size
1510
+ self.mlp_ratio = mlp_ratio
1511
+ # if min(self.input_resolution) <= self.window_size:
1512
+ # # if window size is larger than input resolution, we don't partition windows
1513
+ # self.shift_size = 0
1514
+ # self.window_size = min(self.input_resolution)
1515
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
1516
+
1517
+ self.norm1 = norm_layer(dim)
1518
+ self.attn = WindowAttention(
1519
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
1520
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
1521
+
1522
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1523
+ self.norm2 = norm_layer(dim)
1524
+ mlp_hidden_dim = int(dim * mlp_ratio)
1525
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
1526
+
1527
+ def create_mask(self, H, W):
1528
+ # calculate attention mask for SW-MSA
1529
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
1530
+ h_slices = (slice(0, -self.window_size),
1531
+ slice(-self.window_size, -self.shift_size),
1532
+ slice(-self.shift_size, None))
1533
+ w_slices = (slice(0, -self.window_size),
1534
+ slice(-self.window_size, -self.shift_size),
1535
+ slice(-self.shift_size, None))
1536
+ cnt = 0
1537
+ for h in h_slices:
1538
+ for w in w_slices:
1539
+ img_mask[:, h, w, :] = cnt
1540
+ cnt += 1
1541
+
1542
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
1543
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1544
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1545
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
1546
+
1547
+ return attn_mask
1548
+
1549
+ def forward(self, x):
1550
+ # reshape x[b c h w] to x[b l c]
1551
+ _, _, H_, W_ = x.shape
1552
+
1553
+ Padding = False
1554
+ if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
1555
+ Padding = True
1556
+ # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
1557
+ pad_r = (self.window_size - W_ % self.window_size) % self.window_size
1558
+ pad_b = (self.window_size - H_ % self.window_size) % self.window_size
1559
+ x = F.pad(x, (0, pad_r, 0, pad_b))
1560
+
1561
+ # print('2', x.shape)
1562
+ B, C, H, W = x.shape
1563
+ L = H * W
1564
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
1565
+
1566
+ # create the attention mask here in forward (rather than __init__) because the padded feature size is only known at runtime
1567
+ if self.shift_size > 0:
1568
+ attn_mask = self.create_mask(H, W).to(x.device)
1569
+ else:
1570
+ attn_mask = None
1571
+
1572
+ shortcut = x
1573
+ x = self.norm1(x)
1574
+ x = x.view(B, H, W, C)
1575
+
1576
+ # cyclic shift
1577
+ if self.shift_size > 0:
1578
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
1579
+ else:
1580
+ shifted_x = x
1581
+
1582
+ # partition windows
1583
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
1584
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
1585
+
1586
+ # W-MSA/SW-MSA
1587
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
1588
+
1589
+ # merge windows
1590
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
1591
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
1592
+
1593
+ # reverse cyclic shift
1594
+ if self.shift_size > 0:
1595
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
1596
+ else:
1597
+ x = shifted_x
1598
+ x = x.view(B, H * W, C)
1599
+
1600
+ # FFN
1601
+ x = shortcut + self.drop_path(x)
1602
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1603
+
1604
+ x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
1605
+
1606
+ if Padding:
1607
+ x = x[:, :, :H_, :W_] # reverse padding
1608
+
1609
+ return x
1610
+
1611
+
1612
+ class SwinTransformerBlock(nn.Module):
1613
+ def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
1614
+ super().__init__()
1615
+ self.conv = None
1616
+ if c1 != c2:
1617
+ self.conv = Conv(c1, c2)
1618
+
1619
+ # remove input_resolution
1620
+ self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
1621
+ shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])
1622
+
1623
+ def forward(self, x):
1624
+ if self.conv is not None:
1625
+ x = self.conv(x)
1626
+ x = self.blocks(x)
1627
+ return x
1628
+
1629
+
1630
+ class STCSPA(nn.Module):
1631
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1632
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1633
+ super(STCSPA, self).__init__()
1634
+ c_ = int(c2 * e) # hidden channels
1635
+ self.cv1 = Conv(c1, c_, 1, 1)
1636
+ self.cv2 = Conv(c1, c_, 1, 1)
1637
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
1638
+ num_heads = c_ // 32
1639
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
1640
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1641
+
1642
+ def forward(self, x):
1643
+ y1 = self.m(self.cv1(x))
1644
+ y2 = self.cv2(x)
1645
+ return self.cv3(torch.cat((y1, y2), dim=1))
1646
+
1647
+
1648
+ class STCSPB(nn.Module):
1649
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1650
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1651
+ super(STCSPB, self).__init__()
1652
+ c_ = int(c2) # hidden channels
1653
+ self.cv1 = Conv(c1, c_, 1, 1)
1654
+ self.cv2 = Conv(c_, c_, 1, 1)
1655
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
1656
+ num_heads = c_ // 32
1657
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
1658
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1659
+
1660
+ def forward(self, x):
1661
+ x1 = self.cv1(x)
1662
+ y1 = self.m(x1)
1663
+ y2 = self.cv2(x1)
1664
+ return self.cv3(torch.cat((y1, y2), dim=1))
1665
+
1666
+
1667
+ class STCSPC(nn.Module):
1668
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1669
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1670
+ super(STCSPC, self).__init__()
1671
+ c_ = int(c2 * e) # hidden channels
1672
+ self.cv1 = Conv(c1, c_, 1, 1)
1673
+ self.cv2 = Conv(c1, c_, 1, 1)
1674
+ self.cv3 = Conv(c_, c_, 1, 1)
1675
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
1676
+ num_heads = c_ // 32
1677
+ self.m = SwinTransformerBlock(c_, c_, num_heads, n)
1678
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
1679
+
1680
+ def forward(self, x):
1681
+ y1 = self.cv3(self.m(self.cv1(x)))
1682
+ y2 = self.cv2(x)
1683
+ return self.cv4(torch.cat((y1, y2), dim=1))
1684
+
1685
+ ##### end of swin transformer #####
1686
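+ # Illustrative sketch: window_partition/window_reverse are exact inverses when H and W are
+ # divisible by window_size, which is what SwinTransformerLayer relies on after padding.
+ #
+ # x = torch.randn(2, 32, 32, 96)                       # B, H, W, C
+ # w = window_partition(x, 8)                           # (2*4*4, 8, 8, 96)
+ # assert torch.equal(window_reverse(w, 8, 32, 32), x)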
+
1687
+
1688
+ ##### swin transformer v2 #####
1689
+
1690
+ class WindowAttention_v2(nn.Module):
1691
+
1692
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
1693
+ pretrained_window_size=[0, 0]):
1694
+
1695
+ super().__init__()
1696
+ self.dim = dim
1697
+ self.window_size = window_size # Wh, Ww
1698
+ self.pretrained_window_size = pretrained_window_size
1699
+ self.num_heads = num_heads
1700
+
1701
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
1702
+
1703
+ # mlp to generate continuous relative position bias
1704
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
1705
+ nn.ReLU(inplace=True),
1706
+ nn.Linear(512, num_heads, bias=False))
1707
+
1708
+ # get relative_coords_table
1709
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
1710
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
1711
+ relative_coords_table = torch.stack(
1712
+ torch.meshgrid([relative_coords_h,
1713
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
1714
+ if pretrained_window_size[0] > 0:
1715
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
1716
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
1717
+ else:
1718
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
1719
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
1720
+ relative_coords_table *= 8 # normalize to -8, 8
1721
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1722
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
1723
+
1724
+ self.register_buffer("relative_coords_table", relative_coords_table)
1725
+
1726
+ # get pair-wise relative position index for each token inside the window
1727
+ coords_h = torch.arange(self.window_size[0])
1728
+ coords_w = torch.arange(self.window_size[1])
1729
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1730
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1731
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
1732
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
1733
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
1734
+ relative_coords[:, :, 1] += self.window_size[1] - 1
1735
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
1736
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1737
+ self.register_buffer("relative_position_index", relative_position_index)
1738
+
1739
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
1740
+ if qkv_bias:
1741
+ self.q_bias = nn.Parameter(torch.zeros(dim))
1742
+ self.v_bias = nn.Parameter(torch.zeros(dim))
1743
+ else:
1744
+ self.q_bias = None
1745
+ self.v_bias = None
1746
+ self.attn_drop = nn.Dropout(attn_drop)
1747
+ self.proj = nn.Linear(dim, dim)
1748
+ self.proj_drop = nn.Dropout(proj_drop)
1749
+ self.softmax = nn.Softmax(dim=-1)
1750
+
1751
+ def forward(self, x, mask=None):
1752
+
1753
+ B_, N, C = x.shape
1754
+ qkv_bias = None
1755
+ if self.q_bias is not None:
1756
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
1757
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
1758
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
1759
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
1760
+
1761
+ # cosine attention
1762
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
1763
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
1764
+ attn = attn * logit_scale
1765
+
1766
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
1767
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
1768
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
1769
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
1770
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
1771
+ attn = attn + relative_position_bias.unsqueeze(0)
1772
+
1773
+ if mask is not None:
1774
+ nW = mask.shape[0]
1775
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
1776
+ attn = attn.view(-1, self.num_heads, N, N)
1777
+ attn = self.softmax(attn)
1778
+ else:
1779
+ attn = self.softmax(attn)
1780
+
1781
+ attn = self.attn_drop(attn)
1782
+
1783
+ try:
1784
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
1785
+ except RuntimeError:  # attn/v dtype mismatch under mixed precision; retry with half-precision attn
1786
+ x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
1787
+
1788
+ x = self.proj(x)
1789
+ x = self.proj_drop(x)
1790
+ return x
1791
+
1792
+ def extra_repr(self) -> str:
1793
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
1794
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
1795
+
1796
+ def flops(self, N):
1797
+ # calculate flops for 1 window with token length of N
1798
+ flops = 0
1799
+ # qkv = self.qkv(x)
1800
+ flops += N * self.dim * 3 * self.dim
1801
+ # attn = (q @ k.transpose(-2, -1))
1802
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
1803
+ # x = (attn @ v)
1804
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
1805
+ # x = self.proj(x)
1806
+ flops += N * self.dim * self.dim
1807
+ return flops
1808
+
1809
+ class Mlp_v2(nn.Module):
1810
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
1811
+ super().__init__()
1812
+ out_features = out_features or in_features
1813
+ hidden_features = hidden_features or in_features
1814
+ self.fc1 = nn.Linear(in_features, hidden_features)
1815
+ self.act = act_layer()
1816
+ self.fc2 = nn.Linear(hidden_features, out_features)
1817
+ self.drop = nn.Dropout(drop)
1818
+
1819
+ def forward(self, x):
1820
+ x = self.fc1(x)
1821
+ x = self.act(x)
1822
+ x = self.drop(x)
1823
+ x = self.fc2(x)
1824
+ x = self.drop(x)
1825
+ return x
1826
+
1827
+
1828
+ def window_partition_v2(x, window_size):
1829
+
1830
+ B, H, W, C = x.shape
1831
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
1832
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
1833
+ return windows
1834
+
1835
+
1836
+ def window_reverse_v2(windows, window_size, H, W):
1837
+
1838
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
1839
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
1840
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
1841
+ return x
1842
+
1843
+
1844
+ class SwinTransformerLayer_v2(nn.Module):
1845
+
1846
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
1847
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
1848
+ act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
1849
+ super().__init__()
1850
+ self.dim = dim
1851
+ #self.input_resolution = input_resolution
1852
+ self.num_heads = num_heads
1853
+ self.window_size = window_size
1854
+ self.shift_size = shift_size
1855
+ self.mlp_ratio = mlp_ratio
1856
+ #if min(self.input_resolution) <= self.window_size:
1857
+ # # if window size is larger than input resolution, we don't partition windows
1858
+ # self.shift_size = 0
1859
+ # self.window_size = min(self.input_resolution)
1860
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
1861
+
1862
+ self.norm1 = norm_layer(dim)
1863
+ self.attn = WindowAttention_v2(
1864
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
1865
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
1866
+ pretrained_window_size=(pretrained_window_size, pretrained_window_size))
1867
+
1868
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1869
+ self.norm2 = norm_layer(dim)
1870
+ mlp_hidden_dim = int(dim * mlp_ratio)
1871
+ self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
1872
+
1873
+ def create_mask(self, H, W):
1874
+ # calculate attention mask for SW-MSA
1875
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
1876
+ h_slices = (slice(0, -self.window_size),
1877
+ slice(-self.window_size, -self.shift_size),
1878
+ slice(-self.shift_size, None))
1879
+ w_slices = (slice(0, -self.window_size),
1880
+ slice(-self.window_size, -self.shift_size),
1881
+ slice(-self.shift_size, None))
1882
+ cnt = 0
1883
+ for h in h_slices:
1884
+ for w in w_slices:
1885
+ img_mask[:, h, w, :] = cnt
1886
+ cnt += 1
1887
+
1888
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
1889
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1890
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1891
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
1892
+
1893
+ return attn_mask
1894
+
1895
+ def forward(self, x):
1896
+ # reshape x[b c h w] to x[b l c]
1897
+ _, _, H_, W_ = x.shape
1898
+
1899
+ Padding = False
1900
+ if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
1901
+ Padding = True
1902
+ # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
1903
+ pad_r = (self.window_size - W_ % self.window_size) % self.window_size
1904
+ pad_b = (self.window_size - H_ % self.window_size) % self.window_size
1905
+ x = F.pad(x, (0, pad_r, 0, pad_b))
1906
+
1907
+ # print('2', x.shape)
1908
+ B, C, H, W = x.shape
1909
+ L = H * W
1910
+ x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
1911
+
1912
+ # create the attention mask here in forward (rather than __init__) because the padded feature size is only known at runtime
1913
+ if self.shift_size > 0:
1914
+ attn_mask = self.create_mask(H, W).to(x.device)
1915
+ else:
1916
+ attn_mask = None
1917
+
1918
+ shortcut = x
1919
+ x = x.view(B, H, W, C)
1920
+
1921
+ # cyclic shift
1922
+ if self.shift_size > 0:
1923
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
1924
+ else:
1925
+ shifted_x = x
1926
+
1927
+ # partition windows
1928
+ x_windows = window_partition_v2(shifted_x, self.window_size) # nW*B, window_size, window_size, C
1929
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
1930
+
1931
+ # W-MSA/SW-MSA
1932
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
1933
+
1934
+ # merge windows
1935
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
1936
+ shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W) # B H' W' C
1937
+
1938
+ # reverse cyclic shift
1939
+ if self.shift_size > 0:
1940
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
1941
+ else:
1942
+ x = shifted_x
1943
+ x = x.view(B, H * W, C)
1944
+ x = shortcut + self.drop_path(self.norm1(x))
1945
+
1946
+ # FFN
1947
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
1948
+ x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
1949
+
1950
+ if Padding:
1951
+ x = x[:, :, :H_, :W_] # reverse padding
1952
+
1953
+ return x
1954
+
1955
+ def extra_repr(self) -> str:
1956
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
1957
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
1958
+
1959
+ def flops(self):
1960
+ flops = 0
1961
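+ # NOTE: input_resolution is not set in __init__ above; flops() assumes the caller assigns
+ # self.input_resolution = (H, W) before invoking it.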
+ H, W = self.input_resolution
1962
+ # norm1
1963
+ flops += self.dim * H * W
1964
+ # W-MSA/SW-MSA
1965
+ nW = H * W / self.window_size / self.window_size
1966
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
1967
+ # mlp
1968
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
1969
+ # norm2
1970
+ flops += self.dim * H * W
1971
+ return flops
1972
+
1973
+
1974
+ class SwinTransformer2Block(nn.Module):
1975
+ def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
1976
+ super().__init__()
1977
+ self.conv = None
1978
+ if c1 != c2:
1979
+ self.conv = Conv(c1, c2)
1980
+
1981
+ # remove input_resolution
1982
+ self.blocks = nn.Sequential(*[SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
1983
+ shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])
1984
+
1985
+ def forward(self, x):
1986
+ if self.conv is not None:
1987
+ x = self.conv(x)
1988
+ x = self.blocks(x)
1989
+ return x
1990
+
1991
+
1992
+ class ST2CSPA(nn.Module):
1993
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
1994
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1995
+ super(ST2CSPA, self).__init__()
1996
+ c_ = int(c2 * e) # hidden channels
1997
+ self.cv1 = Conv(c1, c_, 1, 1)
1998
+ self.cv2 = Conv(c1, c_, 1, 1)
1999
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
2000
+ num_heads = c_ // 32
2001
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
2002
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
2003
+
2004
+ def forward(self, x):
2005
+ y1 = self.m(self.cv1(x))
2006
+ y2 = self.cv2(x)
2007
+ return self.cv3(torch.cat((y1, y2), dim=1))
2008
+
2009
+
2010
+ class ST2CSPB(nn.Module):
2011
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
2012
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
2013
+ super(ST2CSPB, self).__init__()
2014
+ c_ = int(c2) # hidden channels
2015
+ self.cv1 = Conv(c1, c_, 1, 1)
2016
+ self.cv2 = Conv(c_, c_, 1, 1)
2017
+ self.cv3 = Conv(2 * c_, c2, 1, 1)
2018
+ num_heads = c_ // 32
2019
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
2020
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
2021
+
2022
+ def forward(self, x):
2023
+ x1 = self.cv1(x)
2024
+ y1 = self.m(x1)
2025
+ y2 = self.cv2(x1)
2026
+ return self.cv3(torch.cat((y1, y2), dim=1))
2027
+
2028
+
2029
+ class ST2CSPC(nn.Module):
2030
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
2031
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
2032
+ super(ST2CSPC, self).__init__()
2033
+ c_ = int(c2 * e) # hidden channels
2034
+ self.cv1 = Conv(c1, c_, 1, 1)
2035
+ self.cv2 = Conv(c1, c_, 1, 1)
2036
+ self.cv3 = Conv(c_, c_, 1, 1)
2037
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
2038
+ num_heads = c_ // 32
2039
+ self.m = SwinTransformer2Block(c_, c_, num_heads, n)
2040
+ #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
2041
+
2042
+ def forward(self, x):
2043
+ y1 = self.cv3(self.m(self.cv1(x)))
2044
+ y2 = self.cv2(x)
2045
+ return self.cv4(torch.cat((y1, y2), dim=1))
2046
+
2047
+ ##### end of swin transformer v2 #####
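+ # Illustrative sketch: SwinTransformerLayer_v2 (like the v1 layer above) pads inputs whose
+ # spatial size is not a multiple of window_size and crops back afterwards, so the output
+ # shape matches the input shape.
+ #
+ # layer = SwinTransformerLayer_v2(dim=96, num_heads=3, window_size=7).eval()
+ # x = torch.randn(1, 96, 20, 20)      # 20 is not divisible by 7 -> padded internally
+ # with torch.no_grad():
+ #     assert layer(x).shape == x.shape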
models/experimental.py ADDED
@@ -0,0 +1,262 @@
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from models.common import Conv, DWConv
7
+ from utils.google_utils import attempt_download
8
+
9
+
10
+ class CrossConv(nn.Module):
11
+ # Cross Convolution Downsample
12
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
13
+ # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
14
+ super(CrossConv, self).__init__()
15
+ c_ = int(c2 * e) # hidden channels
16
+ self.cv1 = Conv(c1, c_, (1, k), (1, s))
17
+ self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
18
+ self.add = shortcut and c1 == c2
19
+
20
+ def forward(self, x):
21
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
22
+
23
+
24
+ class Sum(nn.Module):
25
+ # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
26
+ def __init__(self, n, weight=False): # n: number of inputs
27
+ super(Sum, self).__init__()
28
+ self.weight = weight # apply weights boolean
29
+ self.iter = range(n - 1) # iter object
30
+ if weight:
31
+ self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
32
+
33
+ def forward(self, x):
34
+ y = x[0] # no weight
35
+ if self.weight:
36
+ w = torch.sigmoid(self.w) * 2
37
+ for i in self.iter:
38
+ y = y + x[i + 1] * w[i]
39
+ else:
40
+ for i in self.iter:
41
+ y = y + x[i + 1]
42
+ return y
43
+
44
+
45
+ class MixConv2d(nn.Module):
46
+ # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
47
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
48
+ super(MixConv2d, self).__init__()
49
+ groups = len(k)
50
+ if equal_ch: # equal c_ per group
51
+ i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
52
+ c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
53
+ else: # equal weight.numel() per group
54
+ b = [c2] + [0] * groups
55
+ a = np.eye(groups + 1, groups, k=-1)
56
+ a -= np.roll(a, 1, axis=1)
57
+ a *= np.array(k) ** 2
58
+ a[0] = 1
59
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
60
+
61
+ self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
62
+ self.bn = nn.BatchNorm2d(c2)
63
+ self.act = nn.LeakyReLU(0.1, inplace=True)
64
+
65
+ def forward(self, x):
66
+ return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
67
+
68
+
69
+ class Ensemble(nn.ModuleList):
70
+ # Ensemble of models
71
+ def __init__(self):
72
+ super(Ensemble, self).__init__()
73
+
74
+ def forward(self, x, augment=False):
75
+ y = []
76
+ for module in self:
77
+ y.append(module(x, augment)[0])
78
+ # y = torch.stack(y).max(0)[0] # max ensemble
79
+ # y = torch.stack(y).mean(0) # mean ensemble
80
+ y = torch.cat(y, 1) # nms ensemble
81
+ return y, None # inference, train output
82
+
83
+
84
+
85
+
86
+
87
+ class ORT_NMS(torch.autograd.Function):
88
+ '''ONNX-Runtime NMS operation'''
89
+ @staticmethod
90
+ def forward(ctx,
91
+ boxes,
92
+ scores,
93
+ max_output_boxes_per_class=torch.tensor([100]),
94
+ iou_threshold=torch.tensor([0.45]),
95
+ score_threshold=torch.tensor([0.25])):
96
+ device = boxes.device
97
+ batch = scores.shape[0]
98
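+ # The indices produced below are random placeholders: this forward() only runs while tracing
+ # for ONNX export, where output shapes/dtypes matter; the real NMS is performed by the
+ # exported NonMaxSuppression op defined in symbolic() below.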
+ num_det = random.randint(0, 100)
99
+ batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
100
+ idxs = torch.arange(100, 100 + num_det).to(device)
101
+ zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
102
+ selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
103
+ selected_indices = selected_indices.to(torch.int64)
104
+ return selected_indices
105
+
106
+ @staticmethod
107
+ def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
108
+ return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
109
+
110
+
111
+ class TRT_NMS(torch.autograd.Function):
112
+ '''TensorRT NMS operation'''
113
+ @staticmethod
114
+ def forward(
115
+ ctx,
116
+ boxes,
117
+ scores,
118
+ background_class=-1,
119
+ box_coding=1,
120
+ iou_threshold=0.45,
121
+ max_output_boxes=100,
122
+ plugin_version="1",
123
+ score_activation=0,
124
+ score_threshold=0.25,
125
+ ):
126
+ batch_size, num_boxes, num_classes = scores.shape
127
+ num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
128
+ det_boxes = torch.randn(batch_size, max_output_boxes, 4)
129
+ det_scores = torch.randn(batch_size, max_output_boxes)
130
+ det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
131
+ return num_det, det_boxes, det_scores, det_classes
132
+
133
+ @staticmethod
134
+ def symbolic(g,
135
+ boxes,
136
+ scores,
137
+ background_class=-1,
138
+ box_coding=1,
139
+ iou_threshold=0.45,
140
+ max_output_boxes=100,
141
+ plugin_version="1",
142
+ score_activation=0,
143
+ score_threshold=0.25):
144
+ out = g.op("TRT::EfficientNMS_TRT",
145
+ boxes,
146
+ scores,
147
+ background_class_i=background_class,
148
+ box_coding_i=box_coding,
149
+ iou_threshold_f=iou_threshold,
150
+ max_output_boxes_i=max_output_boxes,
151
+ plugin_version_s=plugin_version,
152
+ score_activation_i=score_activation,
153
+ score_threshold_f=score_threshold,
154
+ outputs=4)
155
+ nums, boxes, scores, classes = out
156
+ return nums, boxes, scores, classes
157
+
158
+
159
+ class ONNX_ORT(nn.Module):
160
+ '''onnx module with ONNX-Runtime NMS operation.'''
161
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
162
+ super().__init__()
163
+ self.device = device if device else torch.device("cpu")
164
+ self.max_obj = torch.tensor([max_obj]).to(device)
165
+ self.iou_threshold = torch.tensor([iou_thres]).to(device)
166
+ self.score_threshold = torch.tensor([score_thres]).to(device)
167
+ self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
168
+ self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
169
+ dtype=torch.float32,
170
+ device=self.device)
171
+
172
+ def forward(self, x):
173
+ boxes = x[:, :, :4]
174
+ conf = x[:, :, 4:5]
175
+ scores = x[:, :, 5:]
176
+ scores *= conf
177
+ boxes @= self.convert_matrix
178
+ max_score, category_id = scores.max(2, keepdim=True)
179
+ dis = category_id.float() * self.max_wh
180
+ nmsbox = boxes + dis
181
+ max_score_tp = max_score.transpose(1, 2).contiguous()
182
+ selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
183
+ X, Y = selected_indices[:, 0], selected_indices[:, 2]
184
+ selected_boxes = boxes[X, Y, :]
185
+ selected_categories = category_id[X, Y, :].float()
186
+ selected_scores = max_score[X, Y, :]
187
+ X = X.unsqueeze(1).float()
188
+ return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
189
+
190
+ class ONNX_TRT(nn.Module):
191
+ '''onnx module with TensorRT NMS operation.'''
192
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
193
+ super().__init__()
194
+ assert max_wh is None
195
+ self.device = device if device else torch.device('cpu')
196
+ self.background_class = -1
197
+ self.box_coding = 1
198
+ self.iou_threshold = iou_thres
199
+ self.max_obj = max_obj
200
+ self.plugin_version = '1'
201
+ self.score_activation = 0
202
+ self.score_threshold = score_thres
203
+
204
+ def forward(self, x):
205
+ boxes = x[:, :, :4]
206
+ conf = x[:, :, 4:5]
207
+ scores = x[:, :, 5:]
208
+ scores *= conf
209
+ num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
210
+ self.iou_threshold, self.max_obj,
211
+ self.plugin_version, self.score_activation,
212
+ self.score_threshold)
213
+ return num_det, det_boxes, det_scores, det_classes
214
+
215
+
216
+ class End2End(nn.Module):
217
+ '''export onnx or tensorrt model with NMS operation.'''
218
+ def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
219
+ super().__init__()
220
+ device = device if device else torch.device('cpu')
221
+ assert isinstance(max_wh, int) or max_wh is None
222
+ self.model = model.to(device)
223
+ self.model.model[-1].end2end = True
224
+ self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
225
+ self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
226
+ self.end2end.eval()
227
+
228
+ def forward(self, x):
229
+ x = self.model(x)
230
+ x = self.end2end(x)
231
+ return x
232
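+ # Illustrative sketch (weight path, image size and opset are placeholders): wrapping a loaded
+ # model with End2End before ONNX export so NMS is embedded in the graph. max_wh=None selects
+ # the TensorRT EfficientNMS path; an integer (e.g. 640) selects the ONNX-Runtime path.
+ #
+ # model = attempt_load('yolov7.pt', map_location='cpu')
+ # e2e = End2End(model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None).eval()
+ # torch.onnx.export(e2e, torch.zeros(1, 3, 640, 640), 'yolov7-e2e.onnx', opset_version=12)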
+
233
+
234
+
235
+
236
+
237
+ def attempt_load(weights, map_location=None):
238
+ # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
239
+ model = Ensemble()
240
+ for w in weights if isinstance(weights, list) else [weights]:
241
+ attempt_download(w)
242
+ ckpt = torch.load(w, map_location=map_location) # load
243
+ model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
244
+
245
+ # Compatibility updates
246
+ for m in model.modules():
247
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
248
+ m.inplace = True # pytorch 1.7.0 compatibility
249
+ elif type(m) is nn.Upsample:
250
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
251
+ elif type(m) is Conv:
252
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
253
+
254
+ if len(model) == 1:
255
+ return model[-1] # return model
256
+ else:
257
+ print('Ensemble created with %s\n' % weights)
258
+ for k in ['names', 'stride']:
259
+ setattr(model, k, getattr(model[-1], k))
260
+ return model # return ensemble
261
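+ # Illustrative sketch (weight paths are placeholders): loading a single checkpoint returns the
+ # model itself, while a list of checkpoints returns an Ensemble whose outputs are concatenated
+ # before NMS.
+ #
+ # model = attempt_load('best.pt', map_location='cpu')             # single model
+ # ensemble = attempt_load(['a.pt', 'b.pt'], map_location='cpu')   # Ensemble of two models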
+
262
+
models/yolo.py ADDED
@@ -0,0 +1,953 @@
1
+ import argparse
2
+ import logging
3
+ import sys
4
+ from copy import deepcopy
5
+
6
+ sys.path.append('./') # to run '$ python *.py' files in subdirectories
7
+ logger = logging.getLogger(__name__)
8
+ import torch
9
+ from models.common import *
10
+ from models.experimental import *
11
+ from utils.autoanchor import check_anchor_order
12
+ from utils.general import make_divisible, check_file, set_logging
13
+ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
14
+ select_device, copy_attr
15
+ from utils.loss import SigmoidBin
16
+
17
+ try:
18
+ import thop # for FLOPS computation
19
+ except ImportError:
20
+ thop = None
21
+
22
+
23
+ class Detect(nn.Module):
24
+ stride = None # strides computed during build
25
+ export = False # onnx export
26
+ end2end = False
27
+ include_nms = False
28
+ concat = False
29
+
30
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
31
+ super(Detect, self).__init__()
32
+ self.nc = nc # number of classes
33
+ self.no = nc + 5 # number of outputs per anchor
34
+ self.nl = len(anchors) # number of detection layers
35
+ self.na = len(anchors[0]) // 2 # number of anchors
36
+ self.grid = [torch.zeros(1)] * self.nl # init grid
37
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
38
+ self.register_buffer('anchors', a) # shape(nl,na,2)
39
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
40
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
41
+
42
+ def forward(self, x):
43
+ # x = x.copy() # for profiling
44
+ z = [] # inference output
45
+ self.training |= self.export
46
+ for i in range(self.nl):
47
+ x[i] = self.m[i](x[i]) # conv
48
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
49
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
50
+
51
+ if not self.training: # inference
52
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
53
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
54
+ y = x[i].sigmoid()
55
+ if not torch.onnx.is_in_onnx_export():
56
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
57
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
58
+ else:
59
+ xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
60
+ xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5)) # new xy
61
+ wh = wh ** 2 * (4 * self.anchor_grid[i].data) # new wh
62
+ y = torch.cat((xy, wh, conf), 4)
63
+ z.append(y.view(bs, -1, self.no))
64
+
65
+ if self.training:
66
+ out = x
67
+ elif self.end2end:
68
+ out = torch.cat(z, 1)
69
+ elif self.include_nms:
70
+ z = self.convert(z)
71
+ out = (z, )
72
+ elif self.concat:
73
+ out = torch.cat(z, 1)
74
+ else:
75
+ out = (torch.cat(z, 1), x)
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def _make_grid(nx=20, ny=20):
81
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
82
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
83
+
84
+ def convert(self, z):
85
+ z = torch.cat(z, 1)
86
+ box = z[:, :, :4]
87
+ conf = z[:, :, 4:5]
88
+ score = z[:, :, 5:]
89
+ score *= conf
90
+ convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
91
+ dtype=torch.float32,
92
+ device=z.device)
93
+ box @= convert_matrix
94
+ return (box, score)
95
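+ # Worked example for Detect.convert() above: box @ convert_matrix maps (cx, cy, w, h) to
+ # (x1, y1, x2, y2), e.g. (10, 20, 4, 6) -> (10-2, 20-3, 10+2, 20+3) = (8, 17, 12, 23).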
+
96
+
97
+ class IDetect(nn.Module):
98
+ stride = None # strides computed during build
99
+ export = False # onnx export
100
+ end2end = False
101
+ include_nms = False
102
+ concat = False
103
+
104
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
105
+ super(IDetect, self).__init__()
106
+ self.nc = nc # number of classes
107
+ self.no = nc + 5 # number of outputs per anchor
108
+ self.nl = len(anchors) # number of detection layers
109
+ self.na = len(anchors[0]) // 2 # number of anchors
110
+ self.grid = [torch.zeros(1)] * self.nl # init grid
111
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
112
+ self.register_buffer('anchors', a) # shape(nl,na,2)
113
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
114
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
115
+
116
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
117
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
118
+
119
+ def forward(self, x):
120
+ # x = x.copy() # for profiling
121
+ z = [] # inference output
122
+ self.training |= self.export
123
+ for i in range(self.nl):
124
+ x[i] = self.m[i](self.ia[i](x[i])) # conv
125
+ x[i] = self.im[i](x[i])
126
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
127
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
128
+
129
+ if not self.training: # inference
130
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
131
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
132
+
133
+ y = x[i].sigmoid()
134
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
135
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
136
+ z.append(y.view(bs, -1, self.no))
137
+
138
+ return x if self.training else (torch.cat(z, 1), x)
139
+
140
+ def fuseforward(self, x):
141
+ # x = x.copy() # for profiling
142
+ z = [] # inference output
143
+ self.training |= self.export
144
+ for i in range(self.nl):
145
+ x[i] = self.m[i](x[i]) # conv
146
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
147
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
148
+
149
+ if not self.training: # inference
150
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
151
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
152
+
153
+ y = x[i].sigmoid()
154
+ if not torch.onnx.is_in_onnx_export():
155
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
156
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
157
+ else:
158
+ xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
159
+ xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5)) # new xy
160
+ wh = wh ** 2 * (4 * self.anchor_grid[i].data) # new wh
161
+ y = torch.cat((xy, wh, conf), 4)
162
+ z.append(y.view(bs, -1, self.no))
163
+
164
+ if self.training:
165
+ out = x
166
+ elif self.end2end:
167
+ out = torch.cat(z, 1)
168
+ elif self.include_nms:
169
+ z = self.convert(z)
170
+ out = (z, )
171
+ elif self.concat:
172
+ out = torch.cat(z, 1)
173
+ else:
174
+ out = (torch.cat(z, 1), x)
175
+
176
+ return out
177
+
178
+ def fuse(self):
179
+ print("IDetect.fuse")
180
+ # fuse ImplicitA and Convolution
181
+ for i in range(len(self.m)):
182
+ c1,c2,_,_ = self.m[i].weight.shape
183
+ c1_,c2_, _,_ = self.ia[i].implicit.shape
184
+ self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)
185
+
186
+ # fuse ImplicitM and Convolution
187
+ for i in range(len(self.m)):
188
+ c1,c2, _,_ = self.im[i].implicit.shape
189
+ self.m[i].bias *= self.im[i].implicit.reshape(c2)
190
+ self.m[i].weight *= self.im[i].implicit.transpose(0,1)
191
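+ # Why this works: ImplicitA adds a learned per-channel constant a before the 1x1 conv, and
+ # ImplicitM multiplies the conv output by a constant m, so
+ #   m * (W(x + a) + b) = (m*W) x + m*(W a + b)
+ # i.e. W a folds into the bias (first loop) and m scales both weight and bias (second loop).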
+
192
+ @staticmethod
193
+ def _make_grid(nx=20, ny=20):
194
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
195
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
196
+
197
+ def convert(self, z):
198
+ z = torch.cat(z, 1)
199
+ box = z[:, :, :4]
200
+ conf = z[:, :, 4:5]
201
+ score = z[:, :, 5:]
202
+ score *= conf
203
+ convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
204
+ dtype=torch.float32,
205
+ device=z.device)
206
+ box @= convert_matrix
207
+ return (box, score)
208
+
209
+
210
+ class IKeypoint(nn.Module):
211
+ stride = None # strides computed during build
212
+ export = False # onnx export
213
+
214
+ def __init__(self, nc=80, anchors=(), nkpt=17, ch=(), inplace=True, dw_conv_kpt=False): # detection layer
215
+ super(IKeypoint, self).__init__()
216
+ self.nc = nc # number of classes
217
+ self.nkpt = nkpt
218
+ self.dw_conv_kpt = dw_conv_kpt
219
+ self.no_det=(nc + 5) # number of outputs per anchor for box and class
220
+ self.no_kpt = 3*self.nkpt ## number of outputs per anchor for keypoints
221
+ self.no = self.no_det+self.no_kpt
222
+ self.nl = len(anchors) # number of detection layers
223
+ self.na = len(anchors[0]) // 2 # number of anchors
224
+ self.grid = [torch.zeros(1)] * self.nl # init grid
225
+ self.flip_test = False
226
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
227
+ self.register_buffer('anchors', a) # shape(nl,na,2)
228
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
229
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no_det * self.na, 1) for x in ch) # output conv
230
+
231
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
232
+ self.im = nn.ModuleList(ImplicitM(self.no_det * self.na) for _ in ch)
233
+
234
+ if self.nkpt is not None:
235
+ if self.dw_conv_kpt: #keypoint head is slightly more complex
236
+ self.m_kpt = nn.ModuleList(
237
+ nn.Sequential(DWConv(x, x, k=3), Conv(x,x),
238
+ DWConv(x, x, k=3), Conv(x, x),
239
+ DWConv(x, x, k=3), Conv(x,x),
240
+ DWConv(x, x, k=3), Conv(x, x),
241
+ DWConv(x, x, k=3), Conv(x, x),
242
+ DWConv(x, x, k=3), nn.Conv2d(x, self.no_kpt * self.na, 1)) for x in ch)
243
+ else: #keypoint head is a single convolution
244
+ self.m_kpt = nn.ModuleList(nn.Conv2d(x, self.no_kpt * self.na, 1) for x in ch)
245
+
246
+ self.inplace = inplace # use in-place ops (e.g. slice assignment)
247
+
248
+ def forward(self, x):
249
+ # x = x.copy() # for profiling
250
+ z = [] # inference output
251
+ self.training |= self.export
252
+ for i in range(self.nl):
253
+ if self.nkpt is None or self.nkpt==0:
254
+ x[i] = self.im[i](self.m[i](self.ia[i](x[i]))) # conv
255
+ else:
256
+ x[i] = torch.cat((self.im[i](self.m[i](self.ia[i](x[i]))), self.m_kpt[i](x[i])), axis=1)
257
+
258
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
259
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
260
+ x_det = x[i][..., :6]
261
+ x_kpt = x[i][..., 6:]
262
+
263
+ if not self.training: # inference
264
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
265
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
266
+ kpt_grid_x = self.grid[i][..., 0:1]
267
+ kpt_grid_y = self.grid[i][..., 1:2]
268
+
269
+ if self.nkpt == 0:
270
+ y = x[i].sigmoid()
271
+ else:
272
+ y = x_det.sigmoid()
273
+
274
+ if self.inplace:
275
+ xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
276
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
277
+ if self.nkpt != 0:
278
+ x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5 + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i] # xy
279
+ x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i] # xy
280
+ #x_kpt[..., 0::3] = (x_kpt[..., ::3] + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i] # xy
281
+ #x_kpt[..., 1::3] = (x_kpt[..., 1::3] + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i] # xy
282
+ #print('=============')
283
+ #print(self.anchor_grid[i].shape)
284
+ #print(self.anchor_grid[i][...,0].unsqueeze(4).shape)
285
+ #print(x_kpt[..., 0::3].shape)
286
+ #x_kpt[..., 0::3] = ((x_kpt[..., 0::3].tanh() * 2.) ** 3 * self.anchor_grid[i][...,0].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_x.repeat(1,1,1,1,17) * self.stride[i] # xy
287
+ #x_kpt[..., 1::3] = ((x_kpt[..., 1::3].tanh() * 2.) ** 3 * self.anchor_grid[i][...,1].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_y.repeat(1,1,1,1,17) * self.stride[i] # xy
288
+ #x_kpt[..., 0::3] = (((x_kpt[..., 0::3].sigmoid() * 4.) ** 2 - 8.) * self.anchor_grid[i][...,0].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_x.repeat(1,1,1,1,17) * self.stride[i] # xy
289
+ #x_kpt[..., 1::3] = (((x_kpt[..., 1::3].sigmoid() * 4.) ** 2 - 8.) * self.anchor_grid[i][...,1].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_y.repeat(1,1,1,1,17) * self.stride[i] # xy
290
+ x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid()
291
+
292
+ y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim = -1)
293
+
294
+ else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
295
+ xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
296
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
297
+ if self.nkpt != 0:
298
+ y[..., 6:] = (y[..., 6:] * 2. - 0.5 + self.grid[i].repeat((1,1,1,1,self.nkpt))) * self.stride[i] # xy
299
+ y = torch.cat((xy, wh, y[..., 4:]), -1)
300
+
301
+ z.append(y.view(bs, -1, self.no))
302
+
303
+ return x if self.training else (torch.cat(z, 1), x)
304
+
305
+ @staticmethod
306
+ def _make_grid(nx=20, ny=20):
307
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
308
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
309
+
310
+ class MT(nn.Module):
311
+ stride = None # strides computed during build
312
+ export = False # onnx export
313
+
314
+ def __init__(self, nc=80, anchors=(), attn=None, mask_iou=False, ch=()): # detection layer
315
+ super(MT, self).__init__()
316
+ self.nc = nc # number of classes
317
+ self.no = nc + 5 # number of outputs per anchor
318
+ self.nl = len(anchors) # number of detection layers
319
+ self.na = len(anchors[0]) // 2 # number of anchors
320
+ self.grid = [torch.zeros(1)] * self.nl # init grid
321
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
322
+ self.register_buffer('anchors', a) # shape(nl,na,2)
323
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
324
+ self.original_anchors = anchors
325
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[0]) # output conv
326
+ if mask_iou:
327
+ self.m_iou = nn.ModuleList(nn.Conv2d(x, self.na, 1) for x in ch[0]) # output conv for per-anchor mask-IoU prediction
328
+ self.mask_iou = mask_iou
329
+ self.attn = attn
330
+ if attn is not None:
331
+ # self.attn_m = nn.ModuleList(nn.Conv2d(x, attn * self.na, 3, padding=1) for x in ch) # output conv
332
+ self.attn_m = nn.ModuleList(nn.Conv2d(x, attn * self.na, 1) for x in ch[0]) # output conv
333
+ #self.attn_m = nn.ModuleList(nn.Conv2d(x, attn * self.na, kernel_size=3, stride=1, padding=1) for x in ch) # output conv
334
+
335
+ def forward(self, x):
336
+ #print(x[1].shape)
337
+ #print(x[2].shape)
338
+ #print([a.shape for a in x])
339
+ #exit()
340
+ # x = x.copy() # for profiling
341
+ z = [] # inference output
342
+ za = []
343
+ zi = []
344
+ attn = [None] * self.nl
345
+ iou = [None] * self.nl
346
+ self.training |= self.export
347
+ output = dict()
348
+ for i in range(self.nl):
349
+ if self.attn is not None:
350
+ attn[i] = self.attn_m[i](x[0][i]) # conv
351
+ bs, _, ny, nx = attn[i].shape # x(bs,2352,20,20) to x(bs,3,20,20,784)
352
+ attn[i] = attn[i].view(bs, self.na, self.attn, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
353
+ if self.mask_iou:
354
+ iou[i] = self.m_iou[i](x[0][i])
355
+ x[0][i] = self.m[i](x[0][i]) # conv
356
+
357
+ bs, _, ny, nx = x[0][i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
358
+ x[0][i] = x[0][i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
359
+ if self.mask_iou:
360
+ iou[i] = iou[i].view(bs, self.na, ny, nx).contiguous()
361
+
362
+ if not self.training: # inference
363
+ za.append(attn[i].view(bs, -1, self.attn))
364
+ if self.mask_iou:
365
+ zi.append(iou[i].view(bs, -1))
366
+ if self.grid[i].shape[2:4] != x[0][i].shape[2:4]:
367
+ self.grid[i] = self._make_grid(nx, ny).to(x[0][i].device)
368
+
369
+ y = x[0][i].sigmoid()
370
+ y[..., 0:2] = (y[..., 0:2] * 3. - 1.0 + self.grid[i]) * self.stride[i] # xy
371
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
372
+ z.append(y.view(bs, -1, self.no))
373
+ output["mask_iou"] = None
374
+ if not self.training:
375
+ output["test"] = torch.cat(z, 1)
376
+ if self.attn is not None:
377
+ output["attn"] = torch.cat(za, 1)
378
+ if self.mask_iou:
379
+ output["mask_iou"] = torch.cat(zi, 1).sigmoid()
380
+
381
+ else:
382
+ if self.attn is not None:
383
+ output["attn"] = attn
384
+ if self.mask_iou:
385
+ output["mask_iou"] = iou
386
+ output["bbox_and_cls"] = x[0]
387
+ output["bases"] = x[1]
388
+ output["sem"] = x[2]
389
+
390
+ return output
391
+
392
+ @staticmethod
393
+ def _make_grid(nx=20, ny=20):
394
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
395
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
396
+
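# --- Editor's note --------------------------------------------------------
# Unlike the other heads, MT returns a dict rather than the usual (pred, x)
# tuple. A hedged sketch of consuming it at inference time; the key names
# come from forward() above, everything else is illustrative only.
def read_mt_inference(output):
    det = output["test"]               # (bs, n_anchors_total, no) decoded boxes + scores
    attn = output.get("attn")          # (bs, n_anchors_total, attn) mask coefficients, if enabled
    mask_iou = output.get("mask_iou")  # (bs, n_anchors_total) predicted mask IoU, if enabled
    return det, attn, mask_iou
# ---------------------------------------------------------------------------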
397
+
398
+ class IAuxDetect(nn.Module):
399
+ stride = None # strides computed during build
400
+ export = False # onnx export
401
+ end2end = False
402
+ include_nms = False
403
+ concat = False
404
+
405
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
406
+ super(IAuxDetect, self).__init__()
407
+ self.nc = nc # number of classes
408
+ self.no = nc + 5 # number of outputs per anchor
409
+ self.nl = len(anchors) # number of detection layers
410
+ self.na = len(anchors[0]) // 2 # number of anchors
411
+ self.grid = [torch.zeros(1)] * self.nl # init grid
412
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
413
+ self.register_buffer('anchors', a) # shape(nl,na,2)
414
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
415
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[:self.nl]) # output conv
416
+ self.m2 = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[self.nl:]) # output conv
417
+
418
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
419
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch[:self.nl])
420
+
421
+ def forward(self, x):
422
+ # x = x.copy() # for profiling
423
+ z = [] # inference output
424
+ self.training |= self.export
425
+ for i in range(self.nl):
426
+ x[i] = self.m[i](self.ia[i](x[i])) # conv
427
+ x[i] = self.im[i](x[i])
428
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
429
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
430
+
431
+ x[i+self.nl] = self.m2[i](x[i+self.nl])
432
+ x[i+self.nl] = x[i+self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
433
+
434
+ if not self.training: # inference
435
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
436
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
437
+
438
+ y = x[i].sigmoid()
439
+ if not torch.onnx.is_in_onnx_export():
440
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
441
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
442
+ else:
443
+ xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
444
+ xy = xy * (2. * self.stride[i]) + (self.stride[i] * (self.grid[i] - 0.5)) # new xy
445
+ wh = wh ** 2 * (4 * self.anchor_grid[i].data) # new wh
446
+ y = torch.cat((xy, wh, conf), 4)
447
+ z.append(y.view(bs, -1, self.no))
448
+
449
+ return x if self.training else (torch.cat(z, 1), x[:self.nl])
450
+
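# --- Editor's note --------------------------------------------------------
# The ONNX-export branch above rewrites the decode so the grid/stride
# constants can be folded into the exported graph. A quick check (with
# illustrative shapes/values) that the two forms are algebraically identical:
import torch

y = torch.rand(1, 3, 20, 20, 2)                       # sigmoid outputs for xy (or wh)
grid = torch.rand(1, 1, 20, 20, 2)
anchor = torch.rand(1, 3, 1, 1, 2)
stride = 8.

xy_a = (y * 2. - 0.5 + grid) * stride                 # eager-mode form
xy_b = y * (2. * stride) + stride * (grid - 0.5)      # export-friendly form
wh_a = (y * 2) ** 2 * anchor
wh_b = y ** 2 * (4 * anchor)
assert torch.allclose(xy_a, xy_b) and torch.allclose(wh_a, wh_b)
# ---------------------------------------------------------------------------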
451
+ def fuseforward(self, x):
452
+ # x = x.copy() # for profiling
453
+ z = [] # inference output
454
+ self.training |= self.export
455
+ for i in range(self.nl):
456
+ x[i] = self.m[i](x[i]) # conv
457
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
458
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
459
+
460
+ if not self.training: # inference
461
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
462
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
463
+
464
+ y = x[i].sigmoid()
465
+ if not torch.onnx.is_in_onnx_export():
466
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
467
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
468
+ else:
469
+ xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
470
+ wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh
471
+ y = torch.cat((xy, wh, y[..., 4:]), -1)
472
+ z.append(y.view(bs, -1, self.no))
473
+
474
+ if self.training:
475
+ out = x
476
+ elif self.end2end:
477
+ out = torch.cat(z, 1)
478
+ elif self.include_nms:
479
+ z = self.convert(z)
480
+ out = (z, )
481
+ elif self.concat:
482
+ out = torch.cat(z, 1)
483
+ else:
484
+ out = (torch.cat(z, 1), x)
485
+
486
+ return out
487
+
488
+ def fuse(self):
489
+ print("IAuxDetect.fuse")
490
+ # fuse ImplicitA and Convolution
491
+ for i in range(len(self.m)):
492
+ c1,c2,_,_ = self.m[i].weight.shape
493
+ c1_,c2_, _,_ = self.ia[i].implicit.shape
494
+ self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1,c2),self.ia[i].implicit.reshape(c2_,c1_)).squeeze(1)
495
+
496
+ # fuse ImplicitM and Convolution
497
+ for i in range(len(self.m)):
498
+ c1,c2, _,_ = self.im[i].implicit.shape
499
+ self.m[i].bias *= self.im[i].implicit.reshape(c2)
500
+ self.m[i].weight *= self.im[i].implicit.transpose(0,1)
501
+
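# --- Editor's note --------------------------------------------------------
# What IAuxDetect.fuse() above does, written out for a single 1x1 conv.
# ImplicitA adds a learned per-channel constant to the input and ImplicitM
# scales the output, so both fold into the conv's weight/bias:
#   conv(x + a) = conv(x) + W @ a      ->  b' = b + W @ a
#   m * conv(x) = (m * W) x + m * b    ->  W'' = m * W,  b'' = m * b
# A hedged numeric check with made-up constants a and m:
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 1)
a = torch.randn(1, 8, 1, 1)    # ImplicitA-style additive constant
m = torch.randn(1, 16, 1, 1)   # ImplicitM-style multiplicative constant

x = torch.randn(2, 8, 5, 5)
ref = m * conv(x + a)          # un-fused path

fused = nn.Conv2d(8, 16, 1)
fused.weight.data = conv.weight.data * m.view(16, 1, 1, 1)
fused.bias.data = (conv.bias.data + conv.weight.data.view(16, 8) @ a.view(8)) * m.view(16)
assert torch.allclose(ref, fused(x), atol=1e-5)
# ---------------------------------------------------------------------------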
502
+ @staticmethod
503
+ def _make_grid(nx=20, ny=20):
504
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
505
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
506
+
507
+ def convert(self, z):
508
+ z = torch.cat(z, 1)
509
+ box = z[:, :, :4]
510
+ conf = z[:, :, 4:5]
511
+ score = z[:, :, 5:]
512
+ score *= conf
513
+ convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
514
+ dtype=torch.float32,
515
+ device=z.device)
516
+ box @= convert_matrix
517
+ return (box, score)
518
+
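# --- Editor's note --------------------------------------------------------
# The constant matrix in convert() maps (cx, cy, w, h) boxes to corner
# format (x1, y1, x2, y2) with a single matmul; an illustrative check:
import torch

convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1],
                               [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], dtype=torch.float32)
box = torch.tensor([[100., 50., 20., 10.]])   # cx, cy, w, h
assert torch.equal(box @ convert_matrix, torch.tensor([[90., 45., 110., 55.]]))  # x1, y1, x2, y2
# ---------------------------------------------------------------------------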
519
+
520
+ class IBin(nn.Module):
521
+ stride = None # strides computed during build
522
+ export = False # onnx export
523
+
524
+ def __init__(self, nc=80, anchors=(), ch=(), bin_count=21): # detection layer
525
+ super(IBin, self).__init__()
526
+ self.nc = nc # number of classes
527
+ self.bin_count = bin_count
528
+
529
+ self.w_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
530
+ self.h_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
531
+ # classes, x,y,obj
532
+ self.no = nc + 3 + \
533
+ self.w_bin_sigmoid.get_length() + self.h_bin_sigmoid.get_length() # w-bce, h-bce
534
+ # + self.x_bin_sigmoid.get_length() + self.y_bin_sigmoid.get_length()
535
+
536
+ self.nl = len(anchors) # number of detection layers
537
+ self.na = len(anchors[0]) // 2 # number of anchors
538
+ self.grid = [torch.zeros(1)] * self.nl # init grid
539
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
540
+ self.register_buffer('anchors', a) # shape(nl,na,2)
541
+ self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
542
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
543
+
544
+ self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
545
+ self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)
546
+
547
+ def forward(self, x):
548
+
549
+ #self.x_bin_sigmoid.use_fw_regression = True
550
+ #self.y_bin_sigmoid.use_fw_regression = True
551
+ self.w_bin_sigmoid.use_fw_regression = True
552
+ self.h_bin_sigmoid.use_fw_regression = True
553
+
554
+ # x = x.copy() # for profiling
555
+ z = [] # inference output
556
+ self.training |= self.export
557
+ for i in range(self.nl):
558
+ x[i] = self.m[i](self.ia[i](x[i])) # conv
559
+ x[i] = self.im[i](x[i])
560
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
561
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
562
+
563
+ if not self.training: # inference
564
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
565
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
566
+
567
+ y = x[i].sigmoid()
568
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
569
+ #y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
570
+
571
+
572
+ #px = (self.x_bin_sigmoid.forward(y[..., 0:12]) + self.grid[i][..., 0]) * self.stride[i]
573
+ #py = (self.y_bin_sigmoid.forward(y[..., 12:24]) + self.grid[i][..., 1]) * self.stride[i]
574
+
575
+ pw = self.w_bin_sigmoid.forward(y[..., 2:24]) * self.anchor_grid[i][..., 0]
576
+ ph = self.h_bin_sigmoid.forward(y[..., 24:46]) * self.anchor_grid[i][..., 1]
577
+
578
+ #y[..., 0] = px
579
+ #y[..., 1] = py
580
+ y[..., 2] = pw
581
+ y[..., 3] = ph
582
+
583
+ y = torch.cat((y[..., 0:4], y[..., 46:]), dim=-1)
584
+
585
+ z.append(y.view(bs, -1, y.shape[-1]))
586
+
587
+ return x if self.training else (torch.cat(z, 1), x)
588
+
589
+ @staticmethod
590
+ def _make_grid(nx=20, ny=20):
591
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
592
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
593
+
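# --- Editor's note --------------------------------------------------------
# IBin regresses box width/height through SigmoidBin (defined elsewhere in
# this repo), which pairs a classification over discrete bins with a small
# sigmoid offset. The sketch below is a rough illustration of that idea
# only, not the exact SigmoidBin implementation:
import torch

def bin_decode(logits, bin_count=21, lo=0.0, hi=4.0):
    """logits: (..., bin_count + 1) = 1 regression logit + bin_count bin scores."""
    centers = torch.linspace(lo, hi, bin_count)     # bin centre values
    idx = logits[..., 1:].argmax(dim=-1)            # most likely bin
    offset = logits[..., 0].sigmoid() - 0.5         # small correction around it
    step = (hi - lo) / (bin_count - 1)
    return centers[idx] + offset * step             # decoded w/h multiplier

print(bin_decode(torch.randn(3, 22)))               # three decoded multipliers
# ---------------------------------------------------------------------------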
594
+
595
+ class Model(nn.Module):
596
+ def __init__(self, cfg='yolor-csp-c.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
597
+ super(Model, self).__init__()
598
+ self.traced = False
599
+ if isinstance(cfg, dict):
600
+ self.yaml = cfg # model dict
601
+ else: # is *.yaml
602
+ import yaml # for torch hub
603
+ self.yaml_file = Path(cfg).name
604
+ with open(cfg) as f:
605
+ self.yaml = yaml.load(f, Loader=yaml.SafeLoader) # model dict
606
+
607
+ # Define model
608
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
609
+ if nc and nc != self.yaml['nc']:
610
+ logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
611
+ self.yaml['nc'] = nc # override yaml value
612
+ if anchors:
613
+ logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
614
+ self.yaml['anchors'] = round(anchors) # override yaml value
615
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
616
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
617
+ # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
618
+
619
+ # Build strides, anchors
620
+ m = self.model[-1] # Detect()
621
+ if isinstance(m, Detect):
622
+ s = 256 # 2x min stride
623
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
624
+ check_anchor_order(m)
625
+ m.anchors /= m.stride.view(-1, 1, 1)
626
+ self.stride = m.stride
627
+ self._initialize_biases() # only run once
628
+ # print('Strides: %s' % m.stride.tolist())
629
+ if isinstance(m, IDetect):
630
+ s = 256 # 2x min stride
631
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
632
+ check_anchor_order(m)
633
+ m.anchors /= m.stride.view(-1, 1, 1)
634
+ self.stride = m.stride
635
+ self._initialize_biases() # only run once
636
+ # print('Strides: %s' % m.stride.tolist())
637
+ if isinstance(m, IAuxDetect):
638
+ s = 256 # 2x min stride
639
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[:4]]) # forward
640
+ #print(m.stride)
641
+ check_anchor_order(m)
642
+ m.anchors /= m.stride.view(-1, 1, 1)
643
+ self.stride = m.stride
644
+ self._initialize_aux_biases() # only run once
645
+ # print('Strides: %s' % m.stride.tolist())
646
+ if isinstance(m, IBin):
647
+ s = 256 # 2x min stride
648
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
649
+ check_anchor_order(m)
650
+ m.anchors /= m.stride.view(-1, 1, 1)
651
+ self.stride = m.stride
652
+ self._initialize_biases_bin() # only run once
653
+ # print('Strides: %s' % m.stride.tolist())
654
+ if isinstance(m, IKeypoint):
655
+ s = 256 # 2x min stride
656
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
657
+ check_anchor_order(m)
658
+ m.anchors /= m.stride.view(-1, 1, 1)
659
+ self.stride = m.stride
660
+ self._initialize_biases_kpt() # only run once
661
+ # print('Strides: %s' % m.stride.tolist())
662
+ if isinstance(m, MT):
663
+ s = 256 # 2x min stride
664
+ temp = self.forward(torch.zeros(1, ch, s, s))
665
+ if isinstance(temp, list):
666
+ temp = temp[0]
667
+ m.stride = torch.tensor([s / x.shape[-2] for x in temp["bbox_and_cls"]]) # forward
668
+ check_anchor_order(m)
669
+ m.anchors /= m.stride.view(-1, 1, 1)
670
+ self.stride = m.stride
671
+ self._initialize_biases()
672
+
673
+ # Init weights, biases
674
+ initialize_weights(self)
675
+ self.info()
676
+ logger.info('')
677
+
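# --- Editor's note --------------------------------------------------------
# The stride computation above runs a dummy 256x256 image through the model
# and divides the input size by each output feature-map height. Worked
# numbers for a standard 3-level P3/P4/P5 head:
#   256 / 32 = 8,  256 / 16 = 16,  256 / 8 = 32   ->  stride = [8, 16, 32]
# The anchors (stored in pixels in the yaml) are then divided by the stride
# so that each level's anchors are expressed in its own grid units.
# ---------------------------------------------------------------------------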
678
+ def forward(self, x, augment=False, profile=False):
679
+ if augment:
680
+ img_size = x.shape[-2:] # height, width
681
+ s = [1, 0.83, 0.67] # scales
682
+ f = [None, 3, None] # flips (2-ud, 3-lr)
683
+ y = [] # outputs
684
+ for si, fi in zip(s, f):
685
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
686
+ yi = self.forward_once(xi)[0] # forward
687
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
688
+ yi[..., :4] /= si # de-scale
689
+ if fi == 2:
690
+ yi[..., 1] = img_size[0] - yi[..., 1] # de-flip ud
691
+ elif fi == 3:
692
+ yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lr
693
+ y.append(yi)
694
+ return torch.cat(y, 1), None # augmented inference, train
695
+ else:
696
+ return self.forward_once(x, profile) # single-scale inference, train
697
+
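# --- Editor's note --------------------------------------------------------
# Augmented inference above runs three passes (1x, 0.83x, 0.67x, with the
# 0.83x pass also left-right flipped) and undoes each transform on the
# predictions before concatenating. A hedged sketch of the undo step for a
# single left-right flipped pass; undo_lr_flip is an illustrative helper:
def undo_lr_flip(pred, img_w, scale):
    pred = pred.clone()
    pred[..., :4] /= scale               # de-scale xywh back to the original resolution
    pred[..., 0] = img_w - pred[..., 0]  # mirror the x centre back
    return pred
# ---------------------------------------------------------------------------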
698
+ def forward_once(self, x, profile=False):
699
+ y, dt = [], [] # outputs
700
+ for m in self.model:
701
+ if m.f != -1: # if not from previous layer
702
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
703
+
704
+ if not hasattr(self, 'traced'):
705
+ self.traced = False
706
+
707
+ if self.traced:
708
+ if isinstance(m, Detect) or isinstance(m, IDetect) or isinstance(m, IAuxDetect) or isinstance(m, IKeypoint):
709
+ break
710
+
711
+ if profile:
712
+ c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
713
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
714
+ for _ in range(10):
715
+ m(x.copy() if c else x)
716
+ t = time_synchronized()
717
+ for _ in range(10):
718
+ m(x.copy() if c else x)
719
+ dt.append((time_synchronized() - t) * 100)
720
+ print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
721
+
722
+ x = m(x) # run
723
+
724
+ y.append(x if m.i in self.save else None) # save output
725
+
726
+ if profile:
727
+ print('%.1fms total' % sum(dt))
728
+ return x
729
+
730
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
731
+ # https://arxiv.org/abs/1708.02002 section 3.3
732
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
733
+ m = self.model[-1] # Detect() module
734
+ for mi, s in zip(m.m, m.stride): # from
735
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
736
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
737
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
738
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
739
+
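# --- Editor's note --------------------------------------------------------
# The bias initialisation above follows the focal-loss prior trick
# (https://arxiv.org/abs/1708.02002, section 3.3). Worked numbers for a
# stride-8 level on 640px images with 80 classes:
import math
s, nc = 8, 80
obj_bias = math.log(8 / (640 / s) ** 2)   # log(8 / 80**2) ~= -6.68  (~8 objects per image)
cls_bias = math.log(0.6 / (nc - 0.99))    # ~= -4.88  (~0.6 positives spread over the classes)
print(obj_bias, cls_bias)
# ---------------------------------------------------------------------------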
740
+ def _initialize_aux_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
741
+ # https://arxiv.org/abs/1708.02002 section 3.3
742
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
743
+ m = self.model[-1] # Detect() module
744
+ for mi, mi2, s in zip(m.m, m.m2, m.stride): # from
745
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
746
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
747
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
748
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
749
+ b2 = mi2.bias.view(m.na, -1) # conv.bias(255) to (3,85)
750
+ b2.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
751
+ b2.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
752
+ mi2.bias = torch.nn.Parameter(b2.view(-1), requires_grad=True)
753
+
754
+ def _initialize_biases_bin(self, cf=None): # initialize biases into Detect(), cf is class frequency
755
+ # https://arxiv.org/abs/1708.02002 section 3.3
756
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
757
+ m = self.model[-1] # Bin() module
758
+ bc = m.bin_count
759
+ for mi, s in zip(m.m, m.stride): # from
760
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
761
+ old = b[:, (0,1,2,bc+3)].data
762
+ obj_idx = 2*bc+4
763
+ b[:, :obj_idx].data += math.log(0.6 / (bc + 1 - 0.99))
764
+ b[:, obj_idx].data += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
765
+ b[:, (obj_idx+1):].data += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
766
+ b[:, (0,1,2,bc+3)].data = old
767
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
768
+
769
+ def _initialize_biases_kpt(self, cf=None): # initialize biases into Detect(), cf is class frequency
770
+ # https://arxiv.org/abs/1708.02002 section 3.3
771
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
772
+ m = self.model[-1] # Detect() module
773
+ for mi, s in zip(m.m, m.stride): # from
774
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
775
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
776
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
777
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
778
+
779
+ def _print_biases(self):
780
+ m = self.model[-1] # Detect() module
781
+ for mi in m.m: # from
782
+ b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
783
+ print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
784
+
785
+ # def _print_weights(self):
786
+ # for m in self.model.modules():
787
+ # if type(m) is Bottleneck:
788
+ # print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
789
+
790
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
791
+ print('Fusing layers... ')
792
+ for m in self.model.modules():
793
+ if isinstance(m, RepConv):
794
+ #print(f" fuse_repvgg_block")
795
+ m.fuse_repvgg_block()
796
+ elif isinstance(m, RepConv_OREPA):
797
+ #print(f" switch_to_deploy")
798
+ m.switch_to_deploy()
799
+ elif type(m) is Conv and hasattr(m, 'bn'):
800
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
801
+ delattr(m, 'bn') # remove batchnorm
802
+ m.forward = m.fuseforward # update forward
803
+ elif isinstance(m, (IDetect, IAuxDetect)):
804
+ m.fuse()
805
+ m.forward = m.fuseforward
806
+ self.info()
807
+ return self
808
+
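# --- Editor's note --------------------------------------------------------
# fuse_conv_and_bn (imported from utils) folds BatchNorm into the preceding
# conv using the standard identity:
#   bn(conv(x)) = gamma * (Wx + b - mean) / sqrt(var + eps) + beta
#   W_fused = W * gamma / sqrt(var + eps)   (per output channel)
#   b_fused = beta + (b - mean) * gamma / sqrt(var + eps)
# A hedged re-derivation check with default gamma=1, beta=0 and a bias-free conv:
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 8, 3, bias=False)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)

w = (bn.weight / torch.sqrt(bn.running_var + bn.eps)).detach()  # per-channel scale
fused = nn.Conv2d(4, 8, 3)
fused.weight.data = conv.weight.data * w.view(8, 1, 1, 1)
fused.bias.data = (bn.bias - bn.running_mean * w).detach()

x = torch.randn(1, 4, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
# ---------------------------------------------------------------------------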
809
+ def nms(self, mode=True): # add or remove NMS module
810
+ present = type(self.model[-1]) is NMS # last layer is NMS
811
+ if mode and not present:
812
+ print('Adding NMS... ')
813
+ m = NMS() # module
814
+ m.f = -1 # from
815
+ m.i = self.model[-1].i + 1 # index
816
+ self.model.add_module(name='%s' % m.i, module=m) # add
817
+ self.eval()
818
+ elif not mode and present:
819
+ print('Removing NMS... ')
820
+ self.model = self.model[:-1] # remove
821
+ return self
822
+
823
+ def autoshape(self): # add autoShape module
824
+ print('Adding autoShape... ')
825
+ m = autoShape(self) # wrap model
826
+ copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
827
+ return m
828
+
829
+ def info(self, verbose=False, img_size=640): # print model information
830
+ model_info(self, verbose, img_size)
831
+
832
+
833
+ def parse_model(d, ch): # model_dict, input_channels(3)
834
+ logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
835
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
836
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
837
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
838
+
839
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
840
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
841
+ m = eval(m) if isinstance(m, str) else m # eval strings
842
+ for j, a in enumerate(args):
843
+ try:
844
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
845
+ except Exception:
846
+ pass
847
+
848
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
849
+ if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
850
+ SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
851
+ Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
852
+ RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
853
+ Res, ResCSPA, ResCSPB, ResCSPC,
854
+ RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
855
+ ResX, ResXCSPA, ResXCSPB, ResXCSPC,
856
+ RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
857
+ Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
858
+ SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
859
+ SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC]:
860
+ c1, c2 = ch[f], args[0]
861
+ if c2 != no: # if not output
862
+ c2 = make_divisible(c2 * gw, 8)
863
+
864
+ args = [c1, c2, *args[1:]]
865
+ if m in [DownC, SPPCSPC, GhostSPPCSPC,
866
+ BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
867
+ RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
868
+ ResCSPA, ResCSPB, ResCSPC,
869
+ RepResCSPA, RepResCSPB, RepResCSPC,
870
+ ResXCSPA, ResXCSPB, ResXCSPC,
871
+ RepResXCSPA, RepResXCSPB, RepResXCSPC,
872
+ GhostCSPA, GhostCSPB, GhostCSPC,
873
+ STCSPA, STCSPB, STCSPC,
874
+ ST2CSPA, ST2CSPB, ST2CSPC]:
875
+ args.insert(2, n) # number of repeats
876
+ n = 1
877
+ elif m is nn.BatchNorm2d:
878
+ args = [ch[f]]
879
+ elif m is Concat:
880
+ c2 = sum([ch[x] for x in f])
881
+ elif m is Chuncat:
882
+ c2 = sum([ch[x] for x in f])
883
+ elif m is Shortcut:
884
+ c2 = ch[f[0]]
885
+ elif m is Foldcut:
886
+ c2 = ch[f] // 2
887
+ elif m in [Detect, IDetect, IAuxDetect, IBin, IKeypoint]:
888
+ args.append([ch[x] for x in f])
889
+ if isinstance(args[1], int): # number of anchors
890
+ args[1] = [list(range(args[1] * 2))] * len(f)
891
+ elif m is ReOrg:
892
+ c2 = ch[f] * 4
893
+ elif m in [Merge]:
894
+ c2 = args[0]
895
+ elif m in [MT]:
896
+ if len(args) == 3:
897
+ args.append(False)
898
+ #print(f)
899
+ #print(len(ch))
900
+ #for x in f:
901
+ # print(ch[x])
902
+ args.append([ch[x] for x in f])
903
+ elif m is Contract:
904
+ c2 = ch[f] * args[0] ** 2
905
+ elif m is Expand:
906
+ c2 = ch[f] // args[0] ** 2
907
+ elif m is Refine:
908
+ args.append([ch[x] for x in f])
909
+ c2 = args[0]
910
+ else:
911
+ c2 = ch[f]
912
+
913
+ m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
914
+ t = str(m)[8:-2].replace('__main__.', '') # module type
915
+ np = sum([x.numel() for x in m_.parameters()]) # number params
916
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
917
+ logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
918
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
919
+ layers.append(m_)
920
+ if i == 0:
921
+ ch = []
922
+ ch.append(c2)
923
+ return nn.Sequential(*layers), sorted(save)
924
+
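# --- Editor's note --------------------------------------------------------
# How the depth/width multiples in parse_model scale a yaml entry; worked
# numbers for depth_multiple gd = 0.33 and width_multiple gw = 0.50:
#   n  = max(round(3 * 0.33), 1)        -> 1 repeat instead of 3
#   c2 = make_divisible(256 * 0.50, 8)  -> 128 output channels
# so an illustrative backbone line like  [-1, 3, BottleneckCSPC, [256]]
# becomes a single BottleneckCSPC(c1, 128, 1) in the scaled model.
# ---------------------------------------------------------------------------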
925
+
926
+ if __name__ == '__main__':
927
+ parser = argparse.ArgumentParser()
928
+ parser.add_argument('--cfg', type=str, default='yolor-csp-c.yaml', help='model.yaml')
929
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
930
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
931
+ opt = parser.parse_args()
932
+ opt.cfg = check_file(opt.cfg) # check file
933
+ set_logging()
934
+ device = select_device(opt.device)
935
+
936
+ # Create model
937
+ model = Model(opt.cfg).to(device)
938
+ model.train()
939
+
940
+ if opt.profile:
941
+ img = torch.rand(1, 3, 640, 640).to(device)
942
+ y = model(img, profile=True)
943
+
944
+ # Profile
945
+ # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
946
+ # y = model(img, profile=True)
947
+
948
+ # Tensorboard
949
+ # from torch.utils.tensorboard import SummaryWriter
950
+ # tb_writer = SummaryWriter()
951
+ # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
952
+ # tb_writer.add_graph(model.model, img) # add model to tensorboard
953
+ # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard