siriuszeina commited on
Commit
b689b5b
1 Parent(s): 277ad49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -45
app.py CHANGED
@@ -13,31 +13,26 @@ import numpy as np
13
  import PIL.Image
14
  import tensorflow as tf
15
 
16
- DESCRIPTION = '# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)'
17
 
18
 
19
  def load_sample_image_paths() -> list[pathlib.Path]:
20
- image_dir = pathlib.Path('images')
21
  if not image_dir.exists():
22
- path = huggingface_hub.hf_hub_download(
23
- 'public-data/sample-images-TADNE',
24
- 'images.tar.gz',
25
- repo_type='dataset')
26
  with tarfile.open(path) as f:
27
  f.extractall()
28
- return sorted(image_dir.glob('*'))
29
 
30
 
31
  def load_model() -> tf.keras.Model:
32
- path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
33
- 'model-resnet_custom_v3.h5')
34
  model = tf.keras.models.load_model(path)
35
  return model
36
 
37
 
38
  def load_labels() -> list[str]:
39
- path = huggingface_hub.hf_hub_download('public-data/DeepDanbooru',
40
- 'tags.txt')
41
  with open(path) as f:
42
  labels = [line.strip() for line in f.readlines()]
43
  return labels
@@ -47,18 +42,13 @@ model = load_model()
47
  labels = load_labels()
48
 
49
 
50
- def predict(
51
- image: PIL.Image.Image, score_threshold: float
52
- ) -> tuple[dict[str, float], dict[str, float], str]:
53
  _, height, width, _ = model.input_shape
54
  image = np.asarray(image)
55
- image = tf.image.resize(image,
56
- size=(height, width),
57
- method=tf.image.ResizeMethod.AREA,
58
- preserve_aspect_ratio=True)
59
  image = image.numpy()
60
  image = dd.image.transform_and_pad_image(image, width, height)
61
- image = image / 255.
62
  probs = model.predict(image[None, ...])[0]
63
  probs = probs.astype(float)
64
 
@@ -72,45 +62,35 @@ def predict(
72
  if prob < score_threshold:
73
  break
74
  result_threshold[label] = prob
75
- result_text = ', '.join(result_all.keys())
76
  return result_threshold, result_all, result_text
77
 
78
 
79
  image_paths = load_sample_image_paths()
80
  examples = [[path.as_posix(), 0.5] for path in image_paths]
81
 
82
- with gr.Blocks(css='style.css') as demo:
83
  gr.Markdown(DESCRIPTION)
84
  with gr.Row():
85
  with gr.Column():
86
- image = gr.Image(label='Input', type='pil')
87
- score_threshold = gr.Slider(label='Score threshold',
88
- minimum=0,
89
- maximum=1,
90
- step=0.05,
91
- value=0.5)
92
- run_button = gr.Button('Run')
93
  with gr.Column():
94
  with gr.Tabs():
95
- with gr.Tab(label='Output'):
96
- result = gr.Label(label='Output', show_label=False)
97
- with gr.Tab(label='JSON'):
98
- result_json = gr.JSON(label='JSON output',
99
- show_label=False)
100
- with gr.Tab(label='Text'):
101
- result_text = gr.Text(label='Text output',
102
- show_label=False,
103
- lines=5)
104
- # gr.Examples(examples=examples,
105
- # inputs=[image, score_threshold],
106
- # outputs=[result, result_json, result_text],
107
- # fn=predict,
108
- # cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
109
-
110
  run_button.click(
111
  fn=predict,
112
  inputs=[image, score_threshold],
113
  outputs=[result, result_json, result_text],
114
- api_name='predict',
115
  )
116
- demo.queue().launch()
 
 
 
13
  import PIL.Image
14
  import tensorflow as tf
15
 
16
+ DESCRIPTION = "# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)"
17
 
18
 
19
  def load_sample_image_paths() -> list[pathlib.Path]:
20
+ image_dir = pathlib.Path("images")
21
  if not image_dir.exists():
22
+ path = huggingface_hub.hf_hub_download("public-data/sample-images-TADNE", "images.tar.gz", repo_type="dataset")
 
 
 
23
  with tarfile.open(path) as f:
24
  f.extractall()
25
+ return sorted(image_dir.glob("*"))
26
 
27
 
28
  def load_model() -> tf.keras.Model:
29
+ path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "model-resnet_custom_v3.h5")
 
30
  model = tf.keras.models.load_model(path)
31
  return model
32
 
33
 
34
  def load_labels() -> list[str]:
35
+ path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "tags.txt")
 
36
  with open(path) as f:
37
  labels = [line.strip() for line in f.readlines()]
38
  return labels
 
42
  labels = load_labels()
43
 
44
 
45
+ def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
 
 
46
  _, height, width, _ = model.input_shape
47
  image = np.asarray(image)
48
+ image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
 
 
 
49
  image = image.numpy()
50
  image = dd.image.transform_and_pad_image(image, width, height)
51
+ image = image / 255.0
52
  probs = model.predict(image[None, ...])[0]
53
  probs = probs.astype(float)
54
 
 
62
  if prob < score_threshold:
63
  break
64
  result_threshold[label] = prob
65
+ result_text = ", ".join(result_all.keys())
66
  return result_threshold, result_all, result_text
67
 
68
 
69
  image_paths = load_sample_image_paths()
70
  examples = [[path.as_posix(), 0.5] for path in image_paths]
71
 
72
+ with gr.Blocks(css="style.css") as demo:
73
  gr.Markdown(DESCRIPTION)
74
  with gr.Row():
75
  with gr.Column():
76
+ image = gr.Image(label="Input", type="pil")
77
+ score_threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.05, value=0.5)
78
+ run_button = gr.Button("Run")
 
 
 
 
79
  with gr.Column():
80
  with gr.Tabs():
81
+ with gr.Tab(label="Output"):
82
+ result = gr.Label(label="Output", show_label=False)
83
+ with gr.Tab(label="JSON"):
84
+ result_json = gr.JSON(label="JSON output", show_label=False)
85
+ with gr.Tab(label="Text"):
86
+ result_text = gr.Text(label="Text output", show_label=False, lines=5)
87
+
 
 
 
 
 
 
 
 
88
  run_button.click(
89
  fn=predict,
90
  inputs=[image, score_threshold],
91
  outputs=[result, result_json, result_text],
92
+ api_name="predict",
93
  )
94
+
95
+ if __name__ == "__main__":
96
+ demo.queue(max_size=20).launch()