File size: 2,816 Bytes
fb59cb8
 
 
 
 
 
e9c5f95
13ac685
 
fb59cb8
 
 
 
 
 
a745ac1
fb59cb8
b689b5b
fb59cb8
 
 
e9c5f95
b689b5b
e9c5f95
 
fb59cb8
 
e9c5f95
b689b5b
e9c5f95
 
 
65ea629
e9c5f95
01b28b7
 
 
476290c
 
fb59cb8
a745ac1
 
13ac685
fb59cb8
b689b5b
fb59cb8
 
b689b5b
fb59cb8
 
bbe49e5
 
 
 
 
 
 
 
fb59cb8
bbe49e5
 
b689b5b
bbe49e5
fb59cb8
 
01b28b7
b689b5b
01b28b7
 
 
476290c
b689b5b
 
01b28b7
bbe49e5
b689b5b
 
 
 
 
 
 
bbe49e5
2e614a5
476290c
bbe49e5
b689b5b
bbe49e5
b689b5b
 
dc63b91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python

from __future__ import annotations

import os
import pathlib
import tarfile
import requests
from io import BytesIO
import deepdanbooru as dd
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import tensorflow as tf
import base64 

# Markdown heading rendered at the top of the demo page; links to the upstream project.
DESCRIPTION = "# [KichangKim/DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)"



def load_model() -> tf.keras.Model:
    """Fetch the DeepDanbooru ``resnet_custom_v3`` weights from the Hub and load them.

    Returns:
        The Keras model deserialised from the downloaded ``.h5`` file.
    """
    weights_file = huggingface_hub.hf_hub_download(
        "public-data/DeepDanbooru",
        "model-resnet_custom_v3.h5",
    )
    return tf.keras.models.load_model(weights_file)


def load_labels() -> list[str]:
    """Download DeepDanbooru's ``tags.txt`` and return its lines as labels.

    Every line is kept, including blank ones, and stripped of surrounding
    whitespace. Order is preserved so that ``labels[i]`` names output ``i`` of
    the model's probability vector.

    Returns:
        The list of tag names, one per line of ``tags.txt``.
    """
    tags_path = huggingface_hub.hf_hub_download("public-data/DeepDanbooru", "tags.txt")
    # Explicit UTF-8: without it the decoding depends on the platform locale
    # (e.g. cp1252 on Windows), which can mis-decode or reject non-ASCII tags.
    with open(tags_path, encoding="utf-8") as tags_file:
        # Iterate the file directly instead of materialising readlines().
        return [line.strip() for line in tags_file]


# Load the network and its tag vocabulary once, at import time; predict() reads
# both as module globals. The tuple is evaluated left to right, model first.
model, labels = load_model(), load_labels()


def _to_rgb_image(image: PIL.Image.Image | str | None) -> PIL.Image.Image:
    """Normalise a ``predict`` input into a 3-channel RGB PIL image.

    Accepts either a PIL image (what ``gr.Image(type="pil")`` delivers) or a
    base64-encoded image string, optionally carrying a data-URL header such as
    ``data:image/png;base64,``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # gr.Image passes None when Run is clicked without an image.
        raise gr.Error("Please provide an input image.")
    if isinstance(image, str):
        payload = image
        if payload.startswith("data:"):
            # Drop the "data:<mime>;base64," header; base64 itself contains no comma.
            payload = payload.partition(",")[2]
        image = PIL.Image.open(BytesIO(base64.b64decode(payload)))
    # np.asarray of an "L" image is 2-D and RGBA is 4-channel, either of which
    # breaks the resize/model below; force 3-channel RGB (a no-op copy for the
    # RGB images gr.Image produces by default).
    return image.convert("RGB")


def predict(
    image: PIL.Image.Image | str | None, score_threshold: float
) -> tuple[dict[str, float], dict[str, float], str]:
    """Tag an image with DeepDanbooru.

    Args:
        image: PIL image, or base64 string (optionally a data URL) encoding one.
        score_threshold: minimum probability for a tag to be reported.

    Returns:
        ``(result_threshold, result_all, result_text)``:
        - ``result_threshold``: label -> probability for tags at or above the threshold.
        - ``result_all``: those tags plus the first tag below the threshold.
        - ``result_text``: comma-separated keys of ``result_all``.
        All three are ordered by descending probability.
    """
    _, height, width, _ = model.input_shape

    rgb_image = _to_rgb_image(image)

    pixels = np.asarray(rgb_image)
    pixels = tf.image.resize(
        pixels,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    pixels = pixels.numpy()
    # Aspect-preserving resize can leave one side short; pad to the exact input size.
    pixels = dd.image.transform_and_pad_image(pixels, width, height)
    pixels = pixels / 255.0
    # Add a batch axis for model.predict, then take the single result row.
    probs = model.predict(pixels[None, ...])[0]
    probs = probs.astype(float)

    indices = np.argsort(probs)[::-1]
    result_all = dict()
    result_threshold = dict()
    for index in indices:
        label = labels[index]
        prob = probs[index]
        # NOTE(review): the label is recorded in result_all before the threshold
        # check, so result_all (and result_text) includes the first tag below
        # the threshold. Confirm this is intended.
        result_all[label] = prob
        if prob < score_threshold:
            break
        result_threshold[label] = prob
    result_text = ", ".join(result_all.keys())
    return result_threshold, result_all, result_text



# Gradio UI: an input column (image, threshold slider, Run button) beside an
# output column that shows predict()'s three return values in separate tabs.
# NOTE(review): style.css is presumably shipped next to this script. Confirm.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Inputs. type="pil" makes the component hand predict() a PIL image.
        with gr.Column():
            image = gr.Image(label="Input", type="pil")
            score_threshold = gr.Slider(label="Score threshold", minimum=0, maximum=1, step=0.05, value=0.5)
            run_button = gr.Button("Run")
        # Outputs, in the same order as predict()'s returned 3-tuple.
        with gr.Column():
            with gr.Tabs():
                with gr.Tab(label="Output"):
                    result = gr.Label(label="Output", show_label=False)
                with gr.Tab(label="JSON"):
                    result_json = gr.JSON(label="JSON output", show_label=False)
                with gr.Tab(label="Text"):
                    result_text = gr.Text(label="Text output", show_label=False, lines=5)
    
    # Run calls predict(); api_name also exposes it as the "predict" API endpoint.
    run_button.click(
        fn=predict,
        inputs=[image, score_threshold],
        outputs=[result, result_json, result_text],
        api_name="predict",
    )

if __name__ == "__main__":
    # Enable request queuing (up to 20 queued requests), then start the server.
    # show_error=True surfaces Python exceptions in the browser UI.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(show_error=True)