Ding Yiwei
commited on
Commit
•
c2523be
1
Parent(s):
540ef0d
Replace 2 `transpose()` with 1 `permute` in TransformerBlock()` (#5645)
Browse files- models/common.py +2 -2
models/common.py
CHANGED
@@ -86,8 +86,8 @@ class TransformerBlock(nn.Module):
|
|
86 |
if self.conv is not None:
|
87 |
x = self.conv(x)
|
88 |
b, _, w, h = x.shape
|
89 |
-
p = x.flatten(2).
|
90 |
-
return self.tr(p + self.linear(p)).
|
91 |
|
92 |
|
93 |
class Bottleneck(nn.Module):
|
|
|
86 |
if self.conv is not None:
|
87 |
x = self.conv(x)
|
88 |
b, _, w, h = x.shape
|
89 |
+
p = x.flatten(2).permute(2, 0, 1)
|
90 |
+
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
|
91 |
|
92 |
|
93 |
class Bottleneck(nn.Module):
|