Update app.py
Browse files
app.py
CHANGED
@@ -14,7 +14,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|
14 |
|
15 |
from src.my_utils.testing_utils import parse_args_paired_testing
|
16 |
from src.de_net import DEResNet
|
17 |
-
from s3diff_tile import S3Diff
|
18 |
from torchvision import transforms
|
19 |
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
|
20 |
|
@@ -24,9 +24,12 @@ tensor_transforms = transforms.Compose([
|
|
24 |
|
25 |
args = parse_args_paired_testing()
|
26 |
|
|
|
|
|
|
|
27 |
# Load scheduler, tokenizer and models.
|
28 |
-
pretrained_model_path = '
|
29 |
-
t2i_path = 'sd-turbo
|
30 |
de_net_path = 'assets/mm-realsr/de_net.pth'
|
31 |
|
32 |
# initialize net_sr
|
|
|
14 |
|
15 |
from src.my_utils.testing_utils import parse_args_paired_testing
|
16 |
from src.de_net import DEResNet
|
17 |
+
from src.s3diff_tile import S3Diff
|
18 |
from torchvision import transforms
|
19 |
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
|
20 |
|
|
|
24 |
|
25 |
args = parse_args_paired_testing()
|
26 |
|
27 |
+
# Run the script to get pretrained models
|
28 |
+
subprocess.run(["bash", "get_pretrained_models.sh"])
|
29 |
+
|
30 |
# Load scheduler, tokenizer and models.
|
31 |
+
pretrained_model_path = 'checkpoints/s3diff.pkl'
|
32 |
+
t2i_path = 'stabilityai/sd-turbo'
|
33 |
de_net_path = 'assets/mm-realsr/de_net.pth'
|
34 |
|
35 |
# initialize net_sr
|