glenn-jocher committed
Commit 6bd5e8b
1 parent: c923fbf

nn.SiLU() export support (#1713)
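This commit extends the export script's activation-patching step, which previously handled only nn.Hardswish, so that models built with nn.SiLU also export cleanly: during export, each Conv module's built-in activation is swapped for an export-friendly equivalent, and the Swish class in utils/activations.py is renamed SiLU to match PyTorch's naming.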

Files changed (2)
  1. models/export.py +7 -4
  2. utils/activations.py +2 -2
models/export.py CHANGED

@@ -15,7 +15,7 @@ import torch.nn as nn
 
 import models
 from models.experimental import attempt_load
-from utils.activations import Hardswish
+from utils.activations import Hardswish, SiLU
 from utils.general import set_logging, check_img_size
 
 if __name__ == '__main__':
@@ -43,9 +43,12 @@ if __name__ == '__main__':
     # Update model
     for k, m in model.named_modules():
         m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
-        if isinstance(m, models.common.Conv) and isinstance(m.act, nn.Hardswish):
-            m.act = Hardswish()  # assign activation
-        # if isinstance(m, models.yolo.Detect):
+        if isinstance(m, models.common.Conv):  # assign export-friendly activations
+            if isinstance(m.act, nn.Hardswish):
+                m.act = Hardswish()
+            elif isinstance(m.act, nn.SiLU):
+                m.act = SiLU()
+        # elif isinstance(m, models.yolo.Detect):
         #     m.forward = m.forward_export  # assign forward (optional)
     model.model[-1].export = True  # set Detect() layer export=True
     y = model(img)  # dry run
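For context, the swap exists because at the time of this commit the ONNX and CoreML exporters could not lower the fused aten::silu op emitted by nn.SiLU, whereas the decomposed x * sigmoid(x) traces to plain Mul and Sigmoid ops that older opsets support. Below is a minimal sketch of the same module-swapping pattern; the toy model, input shape, and output filename are illustrative assumptions, not part of the commit.

    import torch
    import torch.nn as nn


    class SiLU(nn.Module):  # decomposed SiLU: traces to Mul(x, Sigmoid(x))
        @staticmethod
        def forward(x):
            return x * torch.sigmoid(x)


    # Toy model standing in for a YOLOv5 network (illustrative assumption)
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.SiLU())

    # Swap each built-in nn.SiLU for the export-friendly version, as export.py does
    for name, m in model.named_children():
        if isinstance(m, nn.SiLU):
            setattr(model, name, SiLU())

    # The traced graph now contains only ops that older ONNX opsets understand
    torch.onnx.export(model, torch.zeros(1, 3, 64, 64), 'toy.onnx', opset_version=11)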
utils/activations.py CHANGED

@@ -5,8 +5,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-# Swish https://arxiv.org/pdf/1905.02244.pdf ---------------------------------------------------------------------------
-class Swish(nn.Module):  #
+# SiLU https://arxiv.org/pdf/1905.02244.pdf ----------------------------------------------------------------------------
+class SiLU(nn.Module):  # export-friendly version of nn.SiLU()
     @staticmethod
     def forward(x):
         return x * torch.sigmoid(x)
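For reference, the renamed class computes exactly the built-in activation, SiLU(x) = x * sigmoid(x), so the swap changes only how the op is traced, not the model's outputs. A quick equivalence check, as a sketch assuming PyTorch >= 1.7 (where torch.nn.functional.silu exists):

    import torch
    import torch.nn.functional as F
    from utils.activations import SiLU  # the export-friendly class above

    x = torch.randn(8)
    assert torch.allclose(SiLU()(x), F.silu(x))  # identical values, export-friendly trace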