|
from __future__ import annotations |
|
|
|
import argparse |
|
import functools |
|
import html |
|
import os |
|
|
|
import gradio as gr |
|
import huggingface_hub |
|
import numpy as np |
|
import onnxruntime as rt |
|
import pandas as pd |
|
import piexif |
|
import piexif.helper |
|
import PIL.Image |
|
|
|
from Utils import dbimutils |
|
|
|
# Page title and Markdown description passed to the Gradio interface.
TITLE = "WaifuDiffusion v1.4 Tags"
DESCRIPTION = """
Demo for:
- [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2)
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
- [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)

Includes "ready to copy" prompt and a prompt analyzer.

Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)

PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)

Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
"""
|
|
|
# Optional Hugging Face access token. When unset this is None, and
# hf_hub_download then runs without an explicit token instead of the module
# failing at import time with KeyError.
# NOTE(review): assumes the tagger repos below are downloadable anonymously.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hugging Face Hub repositories for each selectable tagger model.
MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# File names downloaded from those repositories.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior callers relied on.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args(argv)
|
|
|
|
|
def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
    """Fetch one file from a Hub repository and open it as an ONNX Runtime session.

    The file is obtained through ``huggingface_hub.hf_hub_download`` using the
    module-level ``HF_TOKEN``. The session is built from the returned local path.
    """
    local_path = huggingface_hub.hf_hub_download(
        model_repo,
        model_filename,
        use_auth_token=HF_TOKEN,
    )
    return rt.InferenceSession(local_path)
|
|
|
|
|
def change_model(model_name):
    """Load the tagger called *model_name* and cache it in ``loaded_models``.

    Args:
        model_name: One of "MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT".

    Returns:
        The session produced by ``load_model`` for that model's repository.

    Raises:
        ValueError: If *model_name* is not one of the known models.
    """
    global loaded_models

    # One table instead of an if/elif chain. An unknown name now fails loudly
    # rather than with UnboundLocalError on the undefined ``model``.
    model_repos = {
        "MOAT": MOAT_MODEL_REPO,
        "SwinV2": SWIN_MODEL_REPO,
        "ConvNext": CONV_MODEL_REPO,
        "ConvNextV2": CONV2_MODEL_REPO,
        "ViT": VIT_MODEL_REPO,
    }
    try:
        repo = model_repos[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of {sorted(model_repos)}"
        ) from None

    model = load_model(repo, MODEL_FILENAME)
    loaded_models[model_name] = model
    return model
|
|
|
|
|
def load_labels() -> tuple[list[str], list[np.int64], list[np.int64], list[np.int64]]:
    """Download the tag CSV and split its rows by tag category.

    The CSV is taken from the MOAT repository only.
    NOTE(review): presumably the label file is identical across all five
    taggers; confirm before relying on it for other repos.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``:
        the CSV's "name" column as a list, followed by the row positions
        whose "category" value is the rating, general and character code
        respectively.
    """
    path = huggingface_hub.hf_hub_download(
        MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
    )
    df = pd.read_csv(path)

    # Category codes found in the CSV's "category" column.
    rating_category = 9
    general_category = 0
    character_category = 4

    tag_names = df["name"].tolist()
    rating_indexes = list(np.where(df["category"] == rating_category)[0])
    general_indexes = list(np.where(df["category"] == general_category)[0])
    character_indexes = list(np.where(df["category"] == character_category)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
|
|
|
|
|
def plaintext_to_html(text):
    """HTML-escape *text* and wrap it in one ``<p>`` element.

    Every newline in the input becomes ``<br>`` followed by a newline.
    """
    escaped_lines = (html.escape(line) for line in text.split("\n"))
    body = "<br>\n".join(escaped_lines)
    return f"<p>{body}</p>"
|
|
|
|
|
def _prepare_image(image: PIL.Image.Image, size: int) -> np.ndarray:
    """Convert *image* into a one-image float32 batch for the ONNX tagger.

    Transparent regions are composited onto white. Channels are reversed from
    RGB to BGR, and the array is squared and resized to *size* with
    ``dbimutils``. NOTE(review): assumes make_square/smart_resize yield a
    size x size image; their behavior lives in Utils.dbimutils.
    """
    # Paste onto a white canvas using the alpha channel as the mask.
    rgba = image.convert("RGBA")
    canvas = PIL.Image.new("RGBA", rgba.size, "WHITE")
    canvas.paste(rgba, mask=rgba)
    array = np.asarray(canvas.convert("RGB"))

    # RGB -> BGR channel order.
    array = array[:, :, ::-1]

    array = dbimutils.make_square(array, size)
    array = dbimutils.smart_resize(array, size)
    array = array.astype(np.float32)
    return np.expand_dims(array, 0)


def _tags_to_prompt(tags: dict) -> tuple[str, str]:
    """Return ``(prompt, raw)`` strings for *tags*, ordered by descending score.

    ``raw`` joins the tag names with ", " unchanged. ``prompt`` additionally
    turns underscores into spaces and backslash-escapes parentheses.
    Presumably the escaping stops prompt parsers from treating the
    parentheses as emphasis syntax.
    """
    ordered = sorted(tags, key=tags.get, reverse=True)
    raw = ", ".join(ordered)
    # "\\(" is a literal backslash + paren; the old "\(" was an invalid escape.
    prompt = raw.replace("_", " ").replace("(", "\\(").replace(")", "\\)")
    return prompt, raw


def _png_info_html(rawimage: PIL.Image.Image) -> str:
    """Render the image's metadata (``info`` plus the EXIF UserComment) as HTML.

    Works on a copy of ``rawimage.info`` so the caller's image is left untouched.
    """
    items = dict(rawimage.info)

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            # Not a well-formed UserComment: fall back to lenient UTF-8 decoding.
            exif_comment = exif_comment.decode("utf8", errors="ignore")
        items["exif comment"] = exif_comment

    # Format/container fields that are not shown to the user.
    for field in (
        "jfif",
        "jfif_version",
        "jfif_unit",
        "jfif_density",
        "dpi",
        "exif",
        "loop",
        "background",
        "timestamp",
        "duration",
    ):
        items.pop(field, None)

    # Checked on the remaining items; the old ``len(info) == 0`` test could
    # never fire because ``info`` always started with the header.
    if not items:
        return "<div><p>Nothing found in the image.</p></div>"

    entries = [
        "<div>\n"
        f"<p><b>{plaintext_to_html(str(key))}</b></p>\n"
        f"<p>{plaintext_to_html(str(value))}</p>\n"
        "</div>\n"
        for key, value in items.items()
    ]
    return "\n<p><h4>PNG Info</h4></p>\n" + "".join(entries)


def predict(
    image: PIL.Image.Image,
    model_name: str,
    general_threshold: float,
    character_threshold: float,
    tag_names: list[str],
    rating_indexes: list[np.int64],
    general_indexes: list[np.int64],
    character_indexes: list[np.int64],
):
    """Tag *image* with the selected model and collect its embedded metadata.

    Args:
        image: Input picture; its ``info`` dict is read but not modified.
        model_name: Key into the module-level ``loaded_models``. Missing or
            unloaded entries are loaded via ``change_model``.
        general_threshold: General tags scoring strictly above this are kept.
        character_threshold: Character tags scoring strictly above this are kept.
        tag_names: Tag name for each model output position.
        rating_indexes: Output positions of the rating tags.
        general_indexes: Output positions of the general tags.
        character_indexes: Output positions of the character tags.

    Returns:
        ``(prompt, raw_tags, ratings, characters, general, info_html)``:
        - prompt: kept general tags by descending score, with underscores as
          spaces and escaped parentheses.
        - raw_tags: the same tags joined unmodified.
        - ratings: dicts mapping every rating tag to its score.
        - characters and general: dicts mapping the kept tags to their scores.
        - info_html: an HTML rendering of the image metadata.
    """
    model = loaded_models.get(model_name)
    if model is None:
        # Models are loaded on first use.
        model = change_model(model_name)

    # Input shape is unpacked as (batch, height, width, channels).
    _, height, _, _ = model.get_inputs()[0].shape

    batch = _prepare_image(image, height)
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    probs = model.run([output_name], {input_name: batch})[0]

    labels = list(zip(tag_names, probs[0].astype(float)))

    rating = dict(labels[i] for i in rating_indexes)
    general_res = {
        name: score
        for name, score in (labels[i] for i in general_indexes)
        if score > general_threshold
    }
    character_res = {
        name: score
        for name, score in (labels[i] for i in character_indexes)
        if score > character_threshold
    }

    prompt, raw_tags = _tags_to_prompt(general_res)
    info = _png_info_html(image)
    return (prompt, raw_tags, rating, character_res, general_res, info)
|
|
|
|
|
def main():
    """Load the default model and the labels, then build and launch the Gradio UI."""
    global loaded_models

    model_choices = ("MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT")
    # Every model starts unloaded; predict() loads missing ones on demand.
    loaded_models = dict.fromkeys(model_choices)

    args = parse_args()

    # The default selection is loaded up front.
    change_model("MOAT")

    tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
    tagger = functools.partial(
        predict,
        tag_names=tag_names,
        rating_indexes=rating_indexes,
        general_indexes=general_indexes,
        character_indexes=character_indexes,
    )

    def threshold_slider(default, label):
        # Both threshold controls span [0, 1] with the configured step.
        return gr.Slider(0, 1, step=args.score_slider_step, value=default, label=label)

    interface = gr.Interface(
        fn=tagger,
        inputs=[
            gr.Image(type="pil", label="Input"),
            gr.Radio(list(model_choices), value="MOAT", label="Model"),
            threshold_slider(args.score_general_threshold, "General Tags Threshold"),
            threshold_slider(args.score_character_threshold, "Character Tags Threshold"),
        ],
        outputs=[
            gr.Textbox(label="Output (string)"),
            gr.Textbox(label="Output (raw string)"),
            gr.Label(label="Rating"),
            gr.Label(label="Output (characters)"),
            gr.Label(label="Output (tags)"),
            gr.HTML(),
        ],
        examples=[["power.jpg", "MOAT", 0.35, 0.85]],
        title=TITLE,
        description=DESCRIPTION,
        allow_flagging="never",
    )
    interface.launch(enable_queue=True, share=args.share)


if __name__ == "__main__":
    main()
|
|