glenn-jocher
commited on
Commit
•
6bd5e8b
1
Parent(s):
c923fbf
nn.SiLU() export support (#1713)
Browse files- models/export.py +7 -4
- utils/activations.py +2 -2
models/export.py
CHANGED
@@ -15,7 +15,7 @@ import torch.nn as nn
|
|
15 |
|
16 |
import models
|
17 |
from models.experimental import attempt_load
|
18 |
-
from utils.activations import Hardswish
|
19 |
from utils.general import set_logging, check_img_size
|
20 |
|
21 |
if __name__ == '__main__':
|
@@ -43,9 +43,12 @@ if __name__ == '__main__':
|
|
43 |
# Update model
|
44 |
for k, m in model.named_modules():
|
45 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
46 |
-
if isinstance(m, models.common.Conv)
|
47 |
-
m.act
|
48 |
-
|
|
|
|
|
|
|
49 |
# m.forward = m.forward_export # assign forward (optional)
|
50 |
model.model[-1].export = True # set Detect() layer export=True
|
51 |
y = model(img) # dry run
|
|
|
15 |
|
16 |
import models
|
17 |
from models.experimental import attempt_load
|
18 |
+
from utils.activations import Hardswish, SiLU
|
19 |
from utils.general import set_logging, check_img_size
|
20 |
|
21 |
if __name__ == '__main__':
|
|
|
43 |
# Update model
|
44 |
for k, m in model.named_modules():
|
45 |
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
46 |
+
if isinstance(m, models.common.Conv): # assign export-friendly activations
|
47 |
+
if isinstance(m.act, nn.Hardswish):
|
48 |
+
m.act = Hardswish()
|
49 |
+
elif isinstance(m.act, nn.SiLU):
|
50 |
+
m.act = SiLU()
|
51 |
+
# elif isinstance(m, models.yolo.Detect):
|
52 |
# m.forward = m.forward_export # assign forward (optional)
|
53 |
model.model[-1].export = True # set Detect() layer export=True
|
54 |
y = model(img) # dry run
|
utils/activations.py
CHANGED
@@ -5,8 +5,8 @@ import torch.nn as nn
|
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
|
8 |
-
#
|
9 |
-
class
|
10 |
@staticmethod
|
11 |
def forward(x):
|
12 |
return x * torch.sigmoid(x)
|
|
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
|
8 |
+
# SiLU https://arxiv.org/pdf/1905.02244.pdf ----------------------------------------------------------------------------
|
9 |
+
class SiLU(nn.Module): # export-friendly version of nn.SiLU()
|
10 |
@staticmethod
|
11 |
def forward(x):
|
12 |
return x * torch.sigmoid(x)
|