siriuszeina commited on
Commit
e73ed2c
1 Parent(s): 95e81ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -17,7 +17,13 @@ import tensorflow as tf
17
 
18
  DESCRIPTION = "# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)"
19
 
20
-
 
 
 
 
 
 
21
 
22
  def load_model() -> tf.keras.Model:
23
  path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5")
@@ -61,7 +67,8 @@ def predict(url: str, score_threshold: float) -> tuple[dict[str, float], dict[st
61
  result_text = ", ".join(result_all.keys())
62
  return result_threshold, result_all, result_text
63
 
64
-
 
65
 
66
  with gr.Blocks(css="style.css") as demo:
67
  gr.Markdown(DESCRIPTION)
 
17
 
18
  DESCRIPTION = "# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)"
19
 
20
+ def load_sample_image_paths() -> list[pathlib.Path]:
21
+ image_dir = pathlib.Path("images")
22
+ if not image_dir.exists():
23
+ path = huggingface_hub.hf_hub_download("public-data/sample-images-TADNE", "images.tar.gz", repo_type="dataset")
24
+ with tarfile.open(path) as f:
25
+ f.extractall()
26
+ return sorted(image_dir.glob("*"))
27
 
28
  def load_model() -> tf.keras.Model:
29
  path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5")
 
67
  result_text = ", ".join(result_all.keys())
68
  return result_threshold, result_all, result_text
69
 
70
+ image_paths = load_sample_image_paths()
71
+ examples = [[path.as_posix(), 0.5] for path in image_paths]
72
 
73
  with gr.Blocks(css="style.css") as demo:
74
  gr.Markdown(DESCRIPTION)