not-lain commited on
Commit
b52c99a
1 Parent(s): 56a56bd

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -1
train.py CHANGED
@@ -10,7 +10,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint
10
  from torch.utils.data import DataLoader
11
  from huggingface_hub import PyTorchModelHubMixin
12
 
13
- from isnet import ISNetDIS
14
 
15
 
16
  # warnings.filterwarnings("ignore")
@@ -22,6 +22,8 @@ def get_net(net_name, img_size):
22
  return ISNetDIS()
23
  elif net_name == "isnet_is":
24
  return ISNetDIS()
 
 
25
  raise NotImplementedError
26
 
27
 
 
10
  from torch.utils.data import DataLoader
11
  from huggingface_hub import PyTorchModelHubMixin
12
 
13
+ from isnet import ISNetDIS, ISNetGTEncoder
14
 
15
 
16
  # warnings.filterwarnings("ignore")
 
22
  return ISNetDIS()
23
  elif net_name == "isnet_is":
24
  return ISNetDIS()
25
+ elif net_name == "isnet_gt":
26
+ return ISNetGTEncoder()
27
  raise NotImplementedError
28
 
29