adamelliotfields committed
Commit 52bf5e0
1 Parent(s): 13b498b

Progress bar for more components

Files changed (4):
  1. lib/__init__.py +0 -2
  2. lib/inference.py +4 -2
  3. lib/loader.py +89 -33
  4. lib/utils.py +0 -11
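
This commit removes the custom progress_bar context manager from lib/utils.py and instead reports progress by calling the Gradio progress tracker directly with (completed, total) tuples. A minimal sketch of that calling convention, assuming the handler receives a gr.Progress instance the way Gradio injects one (illustrative code, not from this repo):

import time

import gradio as gr


def handler(progress=gr.Progress()):
    # gr.Progress accepts a float in [0, 1] or a (completed, total)
    # tuple; the tuple form renders as a "completed/total" counter
    # alongside the description.
    total = 3
    for step in range(total):
        progress((step, total), desc="Working")
        time.sleep(0.1)  # stand-in for real work
    progress((total, total), desc="Working")
    return "done"


demo = gr.Interface(fn=handler, inputs=None, outputs="text")

Declaring progress=gr.Progress() as a default argument is how Gradio knows to inject the tracker into the event handler.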
lib/__init__.py CHANGED
@@ -10,7 +10,6 @@ from .utils import (
     download_repo_files,
     enable_progress_bars,
     load_json,
-    progress_bar,
     read_file,
     timer,
 )
@@ -27,7 +26,6 @@ __all__ = [
     "enable_progress_bars",
     "generate",
     "load_json",
-    "progress_bar",
     "read_file",
     "timer",
 ]
lib/inference.py CHANGED
@@ -16,7 +16,7 @@ from spaces import GPU
 from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import load_json, progress_bar, timer
+from .utils import load_json, timer
 
 
 def parse_prompt_with_arrays(prompt: str) -> list[str]:
@@ -301,8 +301,10 @@ def generate(
             image = pipe(**kwargs).images[0]
             if scale > 1:
                 msg = f"Upscaling {scale}x"
-                with timer(msg, logger=log.info), progress_bar(100, desc=msg, progress=progress):
+                with timer(msg, logger=log.info):
+                    progress((0, 100), desc=msg)
                     image = upscaler.predict(image)
+                    progress((100, 100), desc=msg)
             images.append((image, str(current_seed)))
             current_seed += 1
     except Exception as e:
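
The replacement above simply brackets the long-running upscale with a start tick and an end tick inside the existing timer block. If the pattern recurs, it could be factored into a small helper; a hedged sketch (hypothetical, not part of this commit):

def run_with_ticks(fn, progress, total=100, desc="Working"):
    # Report a 0/total tick, run the zero-argument callable fn,
    # then report total/total. fn and the helper name are illustrative.
    progress((0, total), desc=desc)
    result = fn()
    progress((total, total), desc=desc)
    return result

At the call site this would read image = run_with_ticks(lambda: upscaler.predict(image), progress, desc=msg). Unlike the deleted context manager, neither the inline version nor this helper emits the final tick if the wrapped call raises.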
lib/loader.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttn
 from .config import Config
 from .logger import Logger
 from .upscaler import RealESRGAN
-from .utils import progress_bar, timer
+from .utils import timer
 
 
 class Loader:
@@ -41,6 +41,14 @@ class Loader:
             return issubclass(vae_type, AutoencoderTiny)
         return False
 
+    @property
+    def _has_freeu(self):
+        if self.pipe is not None:
+            attrs = ["b1", "b2", "s1", "s2"]
+            block = self.pipe.unet.up_blocks[0]
+            return all(getattr(block, attr, None) is not None for attr in attrs)
+        return False
+
     def _should_unload_upscaler(self, scale=1):
         if self.upscaler is not None and self.upscaler.scale != scale:
             return True
@@ -84,6 +92,11 @@ class Loader:
             self.pipe.deepcache.disable()
             delattr(self.pipe, "deepcache")
 
+    def _unload_freeu(self, freeu=False):
+        if self._has_freeu and not freeu:
+            self.log.info("Disabling FreeU")
+            self.pipe.disable_freeu()
+
     # Copied from https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
     def _unload_ip_adapter(self):
         if self.ip_adapter is not None:
@@ -110,11 +123,14 @@ class Loader:
         with timer(f"Unloading {self.model}", logger=self.log.info):
             self.pipe.to("cpu")
 
-    def _unload(self, kind="", model="", ip_adapter="", deepcache=1, scale=1):
+    def _unload(self, kind="", model="", ip_adapter="", deepcache=1, scale=1, freeu=False):
         to_unload = []
         if self._should_unload_deepcache(deepcache):  # remove deepcache first
             self._unload_deepcache()
 
+        if self._has_freeu and not freeu:
+            self._unload_freeu()
+
         if self._should_unload_upscaler(scale):
             self._unload_upscaler()
             to_unload.append("upscaler")
@@ -133,12 +149,35 @@ class Loader:
             setattr(self, component, None)
             gc.collect()
 
-    def _load_upscaler(self, scale=1, progress=None):
+    def _should_load_upscaler(self, scale=1):
         if self.upscaler is None and scale > 1:
+            return True
+        return False
+
+    def _should_load_deepcache(self, interval=1):
+        has_deepcache = hasattr(self.pipe, "deepcache")
+        if not has_deepcache and interval != 1:
+            return True
+        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
+            return True
+        return False
+
+    def _should_load_ip_adapter(self, ip_adapter=""):
+        if not self.ip_adapter and ip_adapter:
+            return True
+        return False
+
+    def _should_load_pipeline(self):
+        if self.pipe is None:
+            return True
+        return False
+
+    def _load_upscaler(self, scale=1):
+        if self._should_load_upscaler(scale):
             try:
                 msg = f"Loading {scale}x upscaler"
                 # fmt: off
-                with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
+                with timer(msg, logger=self.log.info):
                     self.upscaler = RealESRGAN(scale, device=self.pipe.device)
                     self.upscaler.load_weights()
                 # fmt: on
@@ -147,32 +186,22 @@ class Loader:
                 self.upscaler = None
 
     def _load_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
-        if not has_deepcache and interval == 1:
-            return
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
-            return
-        self.log.info("Enabling DeepCache")
-        self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
-        self.pipe.deepcache.set_params(cache_interval=interval)
-        self.pipe.deepcache.enable()
+        if self._should_load_deepcache(interval):
+            self.log.info("Enabling DeepCache")
+            self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
+            self.pipe.deepcache.set_params(cache_interval=interval)
+            self.pipe.deepcache.enable()
 
     # https://github.com/ChenyangSi/FreeU
     def _load_freeu(self, freeu=False):
-        block = self.pipe.unet.up_blocks[0]
-        attrs = ["b1", "b2", "s1", "s2"]
-        has_freeu = all(getattr(block, attr, None) is not None for attr in attrs)
-        if has_freeu and not freeu:
-            self.log.info("Disabling FreeU")
-            self.pipe.disable_freeu()
-        elif not has_freeu and freeu:
+        if not self._has_freeu and freeu:
             self.log.info("Enabling FreeU")
             self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
 
-    def _load_ip_adapter(self, ip_adapter="", progress=None):
-        if not self.ip_adapter and ip_adapter:
+    def _load_ip_adapter(self, ip_adapter=""):
+        if self._should_load_ip_adapter(ip_adapter):
             msg = "Loading IP-Adapter"
-            with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
+            with timer(msg, logger=self.log.info):
                 self.pipe.load_ip_adapter(
                     "h94/IP-Adapter",
                     subfolder="models",
@@ -190,7 +219,7 @@ class Loader:
         **kwargs,
     ):
         pipeline = Config.PIPELINES[kind]
-        if self.pipe is None:
+        if self._should_load_pipeline():
            try:
                 with timer(f"Loading {model} ({kind})", logger=self.log.info):
                     self.model = model
@@ -212,11 +241,11 @@ class Loader:
         if self.pipe is not None:
             self.pipe.set_progress_bar_config(disable=progress is not None)
 
-    def _load_vae(self, taesd=False, model="", progress=None):
+    def _load_vae(self, taesd=False, model=""):
         # by default all models use KL
         if self._is_kl_vae and taesd:
             msg = "Loading Tiny VAE"
-            with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
+            with timer(msg, logger=self.log.info):
                 self.pipe.vae = AutoencoderTiny.from_pretrained(
                     pretrained_model_name_or_path="madebyollin/taesd",
                     torch_dtype=self.pipe.dtype,
@@ -225,7 +254,7 @@ class Loader:
 
         if self._is_tiny_vae and not taesd:
             msg = "Loading KL VAE"
-            with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
+            with timer(msg, logger=self.log.info):
                 if model.lower() in Config.MODEL_CHECKPOINTS.keys():
                     self.pipe.vae = AutoencoderKL.from_single_file(
                         f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
@@ -299,7 +328,7 @@ class Loader:
         # defaults to float32
         pipe_kwargs["torch_dtype"] = torch.float16
 
-        self._unload(kind, model, ip_adapter, deepcache, scale)
+        self._unload(kind, model, ip_adapter, deepcache, scale, freeu)
         self._load_pipeline(kind, model, progress, **pipe_kwargs)
 
         # error loading model
@@ -321,8 +350,35 @@ class Loader:
         if not same_scheduler or not same_karras:
             self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
 
-        self._load_vae(taesd, model, progress)
-        self._load_freeu(freeu)
-        self._load_deepcache(deepcache)
-        self._load_ip_adapter(ip_adapter, progress)
-        self._load_upscaler(scale, progress)
+        CURRENT_STEP = 1
+        TOTAL_STEPS = sum(
+            [
+                self._is_kl_vae and taesd,
+                self._is_tiny_vae and not taesd,
+                not self._has_freeu and freeu,
+                self._should_load_deepcache(deepcache),
+                self._should_load_ip_adapter(ip_adapter),
+                self._should_load_upscaler(scale),
+            ]
+        )
+
+        msg = "Loading additional features"
+        if self._is_kl_vae and taesd or self._is_tiny_vae and not taesd:
+            self._load_vae(taesd, model)
+            progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
+            CURRENT_STEP += 1
+        if not self._has_freeu and freeu:
+            self._load_freeu(freeu)
+            progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
+            CURRENT_STEP += 1
+        if self._should_load_deepcache(deepcache):
+            self._load_deepcache(deepcache)
+            progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
+            CURRENT_STEP += 1
+        if self._should_load_ip_adapter(ip_adapter):
+            self._load_ip_adapter(ip_adapter)
+            progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
+            CURRENT_STEP += 1
+        if self._should_load_upscaler(scale):
+            self._load_upscaler(scale)
+            progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
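
The new tail of Loader.load() derives the progress denominator by summing the same boolean predicates that gate each optional component, relying on Python counting True as 1 under sum(). The bookkeeping generalizes to any gated task list; a condensed sketch (the helper and task tuples are made up, only the pattern matches the diff):

from typing import Callable


def run_gated_tasks(tasks: list[tuple[bool, Callable[[], None]]], progress, desc="Loading"):
    # TOTAL_STEPS analogue: only tasks that will actually run
    # count toward the denominator.
    total = sum(should_run for should_run, _ in tasks)
    step = 1  # CURRENT_STEP analogue
    for should_run, task in tasks:
        if should_run:
            task()
            progress((step, total), desc=desc)
            step += 1

In the diff each _should_load_* predicate is evaluated twice, once inside sum() and again in its if branch; that is safe here because the predicates are read-only and no _load_* call changes the state a later predicate inspects.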
lib/utils.py CHANGED
@@ -35,17 +35,6 @@ def timer(message="Operation", logger=print):
     logger(f"{message} took {end - start:.2f}s")
 
 
-@contextmanager
-def progress_bar(total, desc="Loading", progress=None):
-    if progress is None:
-        yield
-    try:
-        progress((0, total), desc=desc)
-        yield
-    finally:
-        progress((total, total), desc=desc)
-
-
 @functools.lru_cache()
 def load_json(path: str) -> dict:
     with open(path, "r", encoding="utf-8") as file:
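
For reference, the deleted helper had a latent bug: when progress was None it yielded and then fell through into the try block, so the generator yielded a second time (and called None), which contextlib reports as an error on exit. A corrected version would need an early return; a sketch, not part of this commit:

from contextlib import contextmanager


@contextmanager
def progress_bar(total, desc="Loading", progress=None):
    # No-op when no tracker is supplied: yield exactly once,
    # then return so the generator stops cleanly.
    if progress is None:
        yield
        return
    progress((0, total), desc=desc)
    try:
        yield
    finally:
        progress((total, total), desc=desc)

The commit sidesteps the helper entirely by calling progress((step, total), desc=...) inline at each site.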