kaizen9's picture
Upload model checkpoints directly from S3
958d6f8 verified
raw
history blame contribute delete
No virus
196 Bytes
from torch import nn
from .layers_registry import fcs
# Register the stock PyTorch linear layer under the 'torch' key.
fcs.register('torch', func=nn.Linear)

# Additionally register Transformer Engine's linear layer under the 'te' key,
# but only if the package can be imported; otherwise the 'te' entry is simply
# not registered and this module still loads.
try:
    import transformer_engine.pytorch as te
    fcs.register('te', func=te.Linear)
except Exception:
    # Deliberately best-effort: any ordinary failure while importing or
    # registering is ignored. `Exception` (rather than a bare `except:`) is
    # used so that KeyboardInterrupt and SystemExit still propagate.
    pass