hysts's picture
hysts HF staff
Clean up
5b24937
raw
history blame
6.52 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import functools
import io
import os
import pathlib
import tarfile
import deepdanbooru as dd
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import tensorflow as tf
from huggingface_hub import hf_hub_download
# Text shown by the Gradio interface; main() passes these as the `title`,
# `description` and `article` arguments of gr.Interface.
TITLE = 'TADNE Image Search with DeepDanbooru'
DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/.
This app shows images similar to the query image from images generated
by the TADNE model with seed 0-99999.
Here, image similarity is measured by the L2 distance of the intermediate
features by the [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru)
model.
Known issues:
- The `Seed` table in the output doesn't refresh properly in gradio 2.9.1.
https://github.com/gradio-app/gradio/issues/921
'''
ARTICLE = None
# Access token forwarded as `use_auth_token` to every hf_hub_download call in
# this file. Read eagerly, so importing the module raises KeyError when the
# TOKEN environment variable is not set.
TOKEN = os.environ['TOKEN']
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the Gradio demo.

    Returns:
        Namespace with ``theme``, ``live``, ``share``, ``port``,
        ``enable_queue`` (cleared by ``--disable-queue``),
        ``allow_flagging`` and ``allow_screenshot``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ('--theme', {'type': str, 'default': 'dark-grass'}),
        ('--live', {'action': 'store_true'}),
        ('--share', {'action': 'store_true'}),
        ('--port', {'type': int}),
        ('--disable-queue', {'dest': 'enable_queue',
                             'action': 'store_false'}),
        ('--allow-flagging', {'type': str, 'default': 'never'}),
        ('--allow-screenshot', {'action': 'store_true'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def download_image_tarball(size: int, dirname: str) -> pathlib.Path:
    """Fetch the tarball of generated sample images from the Hugging Face Hub.

    Args:
        size: Image-size component of the file name inside the dataset
            repository (the file requested is ``{size}/{dirname}.tar``).
        dirname: Base name of the tarball, without the ``.tar`` suffix.

    Returns:
        Local path of the downloaded (or cached) tarball.
    """
    path = hf_hub_download('hysts/TADNE-sample-images',
                           f'{size}/{dirname}.tar',
                           repo_type='dataset',
                           use_auth_token=TOKEN)
    # hf_hub_download hands back a plain string; wrap it so the value really
    # is the pathlib.Path promised by the annotation.
    return pathlib.Path(path)
def load_deepdanbooru_predictions(dirname: str) -> np.ndarray:
    """Download and load the precomputed DeepDanbooru feature array.

    Args:
        dirname: Base name of the ``.npy`` file stored under
            ``prediction_results/deepdanbooru/intermediate_features/`` in
            the dataset repository.

    Returns:
        The array stored in the downloaded ``.npy`` file.
    """
    repo_id = 'hysts/TADNE-sample-images'
    filename = (
        f'prediction_results/deepdanbooru/intermediate_features/{dirname}.npy')
    local_file = hf_hub_download(repo_id,
                                 filename,
                                 repo_type='dataset',
                                 use_auth_token=TOKEN)
    features = np.load(local_file)
    return features
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted entries of the local ``images`` directory.

    If ``images`` does not exist yet, the ``images.tar.gz`` archive is
    downloaded from the ``hysts/sample-images-TADNE`` dataset repository and
    extracted into the current working directory first.
    """
    sample_dir = pathlib.Path('images')
    already_present = sample_dir.exists()
    if not already_present:
        archive = huggingface_hub.hf_hub_download('hysts/sample-images-TADNE',
                                                  'images.tar.gz',
                                                  repo_type='dataset',
                                                  use_auth_token=TOKEN)
        with tarfile.open(archive) as tar:
            tar.extractall()
    entries = list(sample_dir.glob('*'))
    entries.sort()
    return entries
def create_model() -> tf.keras.Model:
    """Build the feature extractor from the DeepDanbooru checkpoint.

    The ``model-resnet_custom_v3.h5`` weights are downloaded from the
    ``hysts/DeepDanbooru`` repository, the network is cut at the output of
    its fourth-from-last layer, and a global average pooling layer is
    stacked on top of that truncated network.

    Returns:
        A Keras ``Sequential`` model made of the truncated network followed
        by the pooling layer.
    """
    weights_file = huggingface_hub.hf_hub_download(
        'hysts/DeepDanbooru',
        'model-resnet_custom_v3.h5',
        use_auth_token=TOKEN)
    full_network = tf.keras.models.load_model(weights_file)
    # Keep everything up to (and including) the fourth layer from the end.
    truncated = tf.keras.Model(full_network.input,
                               full_network.layers[-4].output)
    pooling = tf.keras.layers.GlobalAveragePooling2D()
    return tf.keras.Sequential([truncated, pooling])
def predict(image: PIL.Image.Image, model: tf.keras.Model) -> np.ndarray:
    """Compute the feature vector of a single image.

    The image is area-resized (keeping its aspect ratio) to fit the model's
    input height and width, padded with DeepDanbooru's
    ``transform_and_pad_image``, scaled by 1/255 and run through ``model`` as
    a batch of one.

    Args:
        image: Query image.
        model: Feature extractor; its ``input_shape`` must unpack into four
            entries, of which the second and third are height and width.

    Returns:
        The first (only) row of the model output, cast to ``float``.
    """
    _batch, target_height, target_width, _channels = model.input_shape
    pixels = np.asarray(image)
    resized = tf.image.resize(pixels,
                              size=(target_height, target_width),
                              method=tf.image.ResizeMethod.AREA,
                              preserve_aspect_ratio=True).numpy()
    padded = dd.image.transform_and_pad_image(resized, target_width,
                                              target_height)
    batch = (padded / 255.)[None, ...]
    return model.predict(batch)[0].astype(float)
def run(
    image: PIL.Image.Image,
    nrows: int,
    ncols: int,
    image_size: int,
    dirname: str,
    tarball_path: pathlib.Path,
    deepdanbooru_predictions: np.ndarray,
    model: tf.keras.Model,
) -> tuple[np.ndarray, np.ndarray, str]:
    """Find the stored images closest to ``image`` and tile them into a grid.

    Similarity is the squared L2 distance between the query's features and
    each row of ``deepdanbooru_predictions``; the ``nrows * ncols`` closest
    rows are selected, nearest first.

    Args:
        image: Query image.
        nrows: Number of grid rows.
        ncols: Number of grid columns.
        image_size: Side length in pixels of every image in the tarball.
        dirname: Directory prefix of the members inside the tarball; members
            are read as ``{dirname}/{index:07d}.jpg``.
        tarball_path: Uncompressed tar archive holding the images.
        deepdanbooru_predictions: One feature row per stored image, indexed
            by the image number used in the member file names.
        model: Feature extractor passed through to ``predict``.

    Returns:
        A tuple ``(grid, seeds, seed_text)``: the tiled image array of shape
        ``(nrows * image_size, ncols * image_size, 3)``, the selected indices
        as an ``(nrows, ncols)`` array, and the same indices joined with
        ``', '`` in row-major order.
    """
    # The values are used with range()/reshape(), which need true integers;
    # coerce defensively in case the UI delivers another numeric type.
    nrows, ncols = int(nrows), int(ncols)
    num_results = nrows * ncols

    query_features = predict(image, model)
    distances = ((deepdanbooru_predictions - query_features)**2).sum(axis=1)
    nearest_indices = np.argsort(distances)[:num_results]

    tiles = []
    with tarfile.TarFile(tarball_path) as tar_file:
        for image_index in nearest_indices:
            member = tar_file.getmember(f'{dirname}/{image_index:07d}.jpg')
            with tar_file.extractfile(member) as f:
                data = io.BytesIO(f.read())
            # Use a fresh name here so the `image` parameter is not rebound.
            tiles.append(np.asarray(PIL.Image.open(data)))

    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> one big (H', W', C)
    # picture in which the tiles sit side by side.
    res = (np.asarray(tiles).reshape(nrows, ncols, image_size, image_size,
                                     3).transpose(0, 2, 1, 3, 4).reshape(
                                         nrows * image_size,
                                         ncols * image_size, 3))
    seeds = np.asarray(nearest_indices).reshape(nrows, ncols)
    seed_text = ', '.join(map(str, seeds.ravel().tolist()))
    return res, seeds, seed_text
def main():
    """Load all resources, then build and launch the Gradio interface.

    The image tarball, the precomputed features and the model are loaded
    once and bound to ``run`` so that the interface only has to supply the
    query image and the grid dimensions.
    """
    args = parse_args()

    image_size = 128
    dirname = '0-99999'

    tarball_path = download_image_tarball(image_size, dirname)
    deepdanbooru_predictions = load_deepdanbooru_predictions(dirname)
    model = create_model()

    # Every example row is (image path, number of rows, number of columns).
    examples = [[sample.as_posix(), 2, 5]
                for sample in load_sample_image_paths()]

    search = functools.partial(
        run,
        image_size=image_size,
        dirname=dirname,
        tarball_path=tarball_path,
        deepdanbooru_predictions=deepdanbooru_predictions,
        model=model,
    )
    # Copy run's metadata (name, docstring, ...) onto the partial object.
    search = functools.update_wrapper(search, run)

    input_components = [
        gr.inputs.Image(type='pil', label='Input'),
        gr.inputs.Slider(1, 10, step=1, default=2, label='Number of Rows'),
        gr.inputs.Slider(1, 10, step=1, default=5,
                         label='Number of Columns'),
    ]
    output_components = [
        gr.outputs.Image(type='numpy', label='Output'),
        gr.outputs.Dataframe(type='numpy', label='Seed'),
        gr.outputs.Textbox(label='Seed (text)'),
    ]

    interface = gr.Interface(
        search,
        input_components,
        output_components,
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        theme=args.theme,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    interface.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


# Start the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()