|
import gradio as gr |
|
import huggingface_hub |
|
from PIL import Image |
|
from pathlib import Path |
|
import onnxruntime as rt |
|
import numpy as np |
|
import csv |
|
import spaces |
|
|
|
import onnxruntime as rt |
|
# Fetch the model repository from the Hugging Face Hub and keep the local
# directory it was placed in.
e621_model_path = Path(huggingface_hub.snapshot_download('toynya/Z3D-E621-Convnext'))

# Build the inference session once, at import time. CUDA is listed before
# CPU in the providers list.
# The model path is handed over as a plain string.
# NOTE(review): older onnxruntime releases are believed to accept only
# str/bytes here (not os.PathLike) -- confirm against the pinned version.
e621_model_session = rt.InferenceSession(
    str(e621_model_path / 'model.onnx'),
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# Tag vocabulary: the stripped "name" column of tags-selected.csv, in file
# order. newline='' is what the csv module asks for when a file object is
# given to a reader, so line breaks inside quoted fields are parsed correctly.
with open(
    e621_model_path / 'tags-selected.csv',
    mode='r',
    encoding='utf-8',
    newline='',
) as file:
    csv_reader = csv.DictReader(file)
    e621_model_tags = [row['name'].strip() for row in csv_reader]
|
|
|
|
|
def prepare_image_e621(image: Image.Image, target_size: int) -> np.ndarray:
    """Turn a PIL image into the batched array fed to the ONNX session.

    The image is centred on a white square canvas whose side equals the
    image's longer dimension, resized (bicubic) to ``target_size`` when the
    canvas is not already that size, converted to ``float32`` and given a
    leading batch axis.

    Args:
        image: Source image.
        target_size: Side length, in pixels, of the square output.

    Returns:
        A ``float32`` array of shape ``(1, target_size, target_size, 3)``
        whose channel axis is reversed relative to PIL's RGB order (i.e.
        BGR). Pixel values are left exactly as PIL provides them; no scaling
        or normalisation is applied here.
    """
    width, height = image.size
    side = max(width, height)

    # Integer offsets that centre the image on the square canvas.
    offset = ((side - width) // 2, (side - height) // 2)

    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(image, offset)

    if side != target_size:
        canvas = canvas.resize((target_size, target_size), Image.BICUBIC)

    pixels = np.asarray(canvas, dtype=np.float32)

    # Reverse the last axis: RGB -> BGR.
    bgr_pixels = pixels[:, :, ::-1]

    return np.expand_dims(bgr_pixels, axis=0)
|
|
|
|
|
def predict_e621(image: Image.Image):
    """Run the E621 tagger on one image.

    Args:
        image: Image to tag. May be ``None`` (for example when the form is
            submitted without an image), in which case an error is raised.

    Returns:
        A ``(tag_string, scores)`` tuple. ``scores`` maps each tag name to
        the model output at the same index. ``tag_string`` joins, with
        ``", "``, the tags whose score is above the threshold and then
        replaces underscores with spaces.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    THRESHOLD = 0.3

    # Fail with a readable message instead of an AttributeError on `.size`.
    if image is None:
        raise gr.Error("Please provide an image first.")

    # Preprocess exactly once; the session is fed a 448x448 input.
    image_array = prepare_image_e621(image, 448)

    input_name = 'input_1:0'
    output_name = 'predictions_sigmoid'

    result = e621_model_session.run([output_name], {input_name: image_array})
    # One output was requested and the batch holds a single image.
    result = result[0][0]

    scores = {e621_model_tags[i]: score for i, score in enumerate(result)}
    predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
    tag_string = ', '.join(predicted_tags).replace("_", " ")

    return tag_string, scores
|
|
|
|
|
# Markdown text shown as the `description` of the Gradio interface.
DESCRIPTION = """
E621 Tagger (Z3D-E621-Convnext)
- Image => E621 Pony Prompt
- Mod of [fancyfeast's demo](https://huggingface.co/spaces/fancyfeast/Z3D-E621-Convnext-space) for toynya's [Z3D-E621-Convnext](https://huggingface.co/toynya/Z3D-E621-Convnext)
"""
|
|
|
# Single-function UI: one image in; a copyable tag string and a label view
# of the per-tag scores out. Flagging is turned off.
gradio_app = gr.Interface(
    fn=predict_e621,
    inputs=gr.Image(type='pil', sources=['upload', 'clipboard'], label="Source"),
    outputs=[
        gr.Textbox(show_copy_button=True, label="Tag String"),
        gr.Label(num_top_classes=100, label="Tag Predictions"),
    ],
    allow_flagging="never",
    description=DESCRIPTION,
)
|
|
|
|
|
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    gradio_app.launch()
|
|