3f96a16 881b143 3f96a16 881b143 3f96a16 881b143
1
2
3
4
5
6
7
8
9
10
from torch import nn

# Maps a backend name to the fully-connected (linear) layer class it provides.
# "torch" is always present; "te" is registered only when
# transformer_engine.pytorch can be imported.
FC_CLASS_REGISTRY = {
    "torch": nn.Linear,
}

try:
    import transformer_engine.pytorch as te
except Exception:
    # transformer_engine is an optional backend: any failure to import it is
    # treated as "not available". Catching Exception (instead of a bare
    # ``except``) keeps KeyboardInterrupt / SystemExit propagating, so a
    # Ctrl-C during this import is no longer silently swallowed.
    pass
else:
    # Only touch the registry once the import has actually succeeded.
    FC_CLASS_REGISTRY["te"] = te.Linear