hysts HF staff commited on
Commit
6f3a230
1 Parent(s): 1489344

Fix type annotation

Browse files
Files changed (1) hide show
  1. model.py +4 -4
model.py CHANGED
@@ -6,8 +6,8 @@ import sys
6
 
7
  import PIL.Image
8
  import torch
9
- from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, PNDMPipeline,
10
- PNDMScheduler)
11
 
12
  HF_TOKEN = os.environ['HF_TOKEN']
13
 
@@ -39,7 +39,7 @@ class Model:
39
  self.scheduler_type)
40
 
41
  def _load_pipeline(self, model_name: str,
42
- scheduler_type: str) -> DDIMPipeline | DDPMPipeline:
43
  repo_id = f'hysts/diffusers-anime-faces-{model_name}'
44
  if scheduler_type == 'DDPM':
45
  pipeline = DDPMPipeline.from_pretrained(repo_id,
@@ -74,7 +74,7 @@ class Model:
74
 
75
  logger.info('--- done ---')
76
 
77
- def _download_all_models(self):
78
  for name in self.MODEL_NAMES:
79
  self._load_pipeline(name, 'DDPM')
80
 
 
6
 
7
  import PIL.Image
8
  import torch
9
+ from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
10
+ DiffusionPipeline, PNDMPipeline, PNDMScheduler)
11
 
12
  HF_TOKEN = os.environ['HF_TOKEN']
13
 
 
39
  self.scheduler_type)
40
 
41
  def _load_pipeline(self, model_name: str,
42
+ scheduler_type: str) -> DiffusionPipeline:
43
  repo_id = f'hysts/diffusers-anime-faces-{model_name}'
44
  if scheduler_type == 'DDPM':
45
  pipeline = DDPMPipeline.from_pretrained(repo_id,
 
74
 
75
  logger.info('--- done ---')
76
 
77
+ def _download_all_models(self) -> None:
78
  for name in self.MODEL_NAMES:
79
  self._load_pipeline(name, 'DDPM')
80