glenn-jocher commited on
Commit
9c6732f
1 Parent(s): 306fc01

Update variables (#4273)

Browse files
Files changed (1) hide show
  1. models/common.py +7 -7
models/common.py CHANGED
@@ -30,7 +30,7 @@ def autopad(k, p=None): # kernel, padding
30
 
31
 
32
  def DWConv(c1, c2, k=1, s=1, act=True):
33
- # Depthwise convolution
34
  return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
35
 
36
 
@@ -183,11 +183,11 @@ class Contract(nn.Module):
183
  self.gain = gain
184
 
185
  def forward(self, x):
186
- N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
187
  s = self.gain
188
- x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2)
189
  x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
190
- return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40)
191
 
192
 
193
  class Expand(nn.Module):
@@ -197,11 +197,11 @@ class Expand(nn.Module):
197
  self.gain = gain
198
 
199
  def forward(self, x):
200
- N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
201
  s = self.gain
202
- x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80)
203
  x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
204
- return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160)
205
 
206
 
207
  class Concat(nn.Module):
 
30
 
31
 
32
  def DWConv(c1, c2, k=1, s=1, act=True):
33
+ # Depth-wise convolution
34
  return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
35
 
36
 
 
183
  self.gain = gain
184
 
185
  def forward(self, x):
186
+ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
187
  s = self.gain
188
+ x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
189
  x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
190
+ return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
191
 
192
 
193
  class Expand(nn.Module):
 
197
  self.gain = gain
198
 
199
  def forward(self, x):
200
+ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
201
  s = self.gain
202
+ x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
203
  x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
204
+ return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
205
 
206
 
207
  class Concat(nn.Module):