Spaces:
Runtime error
Runtime error
Fix type annotation
Browse files
model.py
CHANGED
@@ -6,8 +6,8 @@ import sys
|
|
6 |
|
7 |
import PIL.Image
|
8 |
import torch
|
9 |
-
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
|
10 |
-
PNDMScheduler)
|
11 |
|
12 |
HF_TOKEN = os.environ['HF_TOKEN']
|
13 |
|
@@ -39,7 +39,7 @@ class Model:
|
|
39 |
self.scheduler_type)
|
40 |
|
41 |
def _load_pipeline(self, model_name: str,
|
42 |
-
scheduler_type: str) ->
|
43 |
repo_id = f'hysts/diffusers-anime-faces-{model_name}'
|
44 |
if scheduler_type == 'DDPM':
|
45 |
pipeline = DDPMPipeline.from_pretrained(repo_id,
|
@@ -74,7 +74,7 @@ class Model:
|
|
74 |
|
75 |
logger.info('--- done ---')
|
76 |
|
77 |
-
def _download_all_models(self):
|
78 |
for name in self.MODEL_NAMES:
|
79 |
self._load_pipeline(name, 'DDPM')
|
80 |
|
|
|
6 |
|
7 |
import PIL.Image
|
8 |
import torch
|
9 |
+
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
|
10 |
+
DiffusionPipeline, PNDMPipeline, PNDMScheduler)
|
11 |
|
12 |
HF_TOKEN = os.environ['HF_TOKEN']
|
13 |
|
|
|
39 |
self.scheduler_type)
|
40 |
|
41 |
def _load_pipeline(self, model_name: str,
|
42 |
+
scheduler_type: str) -> DiffusionPipeline:
|
43 |
repo_id = f'hysts/diffusers-anime-faces-{model_name}'
|
44 |
if scheduler_type == 'DDPM':
|
45 |
pipeline = DDPMPipeline.from_pretrained(repo_id,
|
|
|
74 |
|
75 |
logger.info('--- done ---')
|
76 |
|
77 |
+
def _download_all_models(self) -> None:
|
78 |
for name in self.MODEL_NAMES:
|
79 |
self._load_pipeline(name, 'DDPM')
|
80 |
|