Ricercar committed on
Commit
8ec7aea
1 Parent(s): 45c5cd3

for debugging

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -7,6 +7,7 @@ import torchvision.transforms as T
7
 
8
  from clip_interrogator import Config, Interrogator
9
  from diffusers import StableDiffusionPipeline
 
10
 
11
  from ditail import DitailDemo, seed_everything
12
 
@@ -82,20 +83,21 @@ class WebApp():
82
  self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
83
  if not self.debug_mode:
84
  self.init_interrogator()
 
85
 
86
 
87
  def init_interrogator(self):
88
  cache_path = os.environ.get('HF_HOME')
89
- if cache_path:
90
- config = Config(cache_path=cache_path)
91
- else:
92
- config = Config()
93
  config.clip_model_name = self.args_base['clip_model_name']
94
  config.caption_model_name = self.args_base['caption_model_name']
95
  self.ci = Interrogator(config)
96
  self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
97
  self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
98
 
 
99
 
100
  def _preload_pipeline(self):
101
  for model in BASE_MODEL.values():
@@ -206,8 +208,9 @@ class WebApp():
206
 
207
  return ditail.run_ditail(), self.args_to_show
208
  # return self.args['img'], self.args
209
- except:
210
- print("Unknown error occurs")
 
211
 
212
  def run_example(self, img, prompt, inv_model, spl_model, lora):
213
  return self.run_ditail(img, prompt, spl_model, gr.State(lora), inv_model)
 
7
 
8
  from clip_interrogator import Config, Interrogator
9
  from diffusers import StableDiffusionPipeline
10
+ from transformers import file_utils
11
 
12
  from ditail import DitailDemo, seed_everything
13
 
 
83
  self.debug_mode = debug_mode # turn off clip interrogator when debugging for faster building speed
84
  if not self.debug_mode:
85
  self.init_interrogator()
86
+
87
 
88
 
89
  def init_interrogator(self):
90
  cache_path = os.environ.get('HF_HOME')
91
+ print(f"Intended cache dir: {cache_path}")
92
+ config = Config()
93
+ config.cache_path = cache_path
 
94
  config.clip_model_name = self.args_base['clip_model_name']
95
  config.caption_model_name = self.args_base['caption_model_name']
96
  self.ci = Interrogator(config)
97
  self.ci.config.chunk_size = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
98
  self.ci.config.flavor_intermediate_count = 2048 if self.ci.config.clip_model_name == "ViT-L-14/openai" else 1024
99
 
100
+ print(f"HF cache dir: {file_utils.default_cache_path}")
101
 
102
  def _preload_pipeline(self):
103
  for model in BASE_MODEL.values():
 
208
 
209
  return ditail.run_ditail(), self.args_to_show
210
  # return self.args['img'], self.args
211
+ except Exception as e:
212
+ print(f"Error catched: {e}")
213
+ gr.Markdown(f"**Error catched: {e}**")
214
 
215
  def run_example(self, img, prompt, inv_model, spl_model, lora):
216
  return self.run_ditail(img, prompt, spl_model, gr.State(lora), inv_model)