#!/usr/bin/env python from __future__ import annotations import argparse import functools import io import os import pathlib import tarfile import deepdanbooru as dd import gradio as gr import huggingface_hub import numpy as np import PIL.Image import tensorflow as tf from huggingface_hub import hf_hub_download TITLE = 'TADNE Image Search with DeepDanbooru' DESCRIPTION = '''The original TADNE site is https://thisanimedoesnotexist.ai/. This app shows images similar to the query image from images generated by the TADNE model with seed 0-99999. Here, image similarity is measured by the L2 distance of the intermediate features by the [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) model. Known issues: - The `Seed` table in the output doesn't refresh properly in gradio 2.9.1. https://github.com/gradio-app/gradio/issues/921 ''' ARTICLE = None TOKEN = os.environ['TOKEN'] def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--theme', type=str, default='dark-grass') parser.add_argument('--live', action='store_true') parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') parser.add_argument('--allow-flagging', type=str, default='never') parser.add_argument('--allow-screenshot', action='store_true') return parser.parse_args() def download_image_tarball(size: int, dirname: str) -> pathlib.Path: path = hf_hub_download('hysts/TADNE-sample-images', f'{size}/{dirname}.tar', repo_type='dataset', use_auth_token=TOKEN) return path def load_deepdanbooru_predictions(dirname: str) -> np.ndarray: path = hf_hub_download( 'hysts/TADNE-sample-images', f'prediction_results/deepdanbooru/intermediate_features/{dirname}.npy', repo_type='dataset', use_auth_token=TOKEN) return np.load(path) def load_sample_image_paths() -> list[pathlib.Path]: image_dir = pathlib.Path('images') if not image_dir.exists(): dataset_repo = 'hysts/sample-images-TADNE' path = huggingface_hub.hf_hub_download(dataset_repo, 'images.tar.gz', repo_type='dataset', use_auth_token=TOKEN) with tarfile.open(path) as f: f.extractall() return sorted(image_dir.glob('*')) def create_model() -> tf.keras.Model: path = huggingface_hub.hf_hub_download('hysts/DeepDanbooru', 'model-resnet_custom_v3.h5', use_auth_token=TOKEN) model = tf.keras.models.load_model(path) model = tf.keras.Model(model.input, model.layers[-4].output) layer = tf.keras.layers.GlobalAveragePooling2D() model = tf.keras.Sequential([model, layer]) return model def predict(image: PIL.Image.Image, model: tf.keras.Model) -> np.ndarray: _, height, width, _ = model.input_shape image = np.asarray(image) image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True) image = image.numpy() image = dd.image.transform_and_pad_image(image, width, height) image = image / 255. features = model.predict(image[None, ...])[0] features = features.astype(float) return features def run( image: PIL.Image.Image, nrows: int, ncols: int, image_size: int, dirname: str, tarball_path: pathlib.Path, deepdanbooru_predictions: np.ndarray, model: tf.keras.Model, ) -> tuple[np.ndarray, np.ndarray]: features = predict(image, model) distances = ((deepdanbooru_predictions - features)**2).sum(axis=1) image_indices = np.argsort(distances) seeds = [] images = [] with tarfile.TarFile(tarball_path) as tar_file: for index in range(nrows * ncols): image_index = image_indices[index] seeds.append(image_index) member = tar_file.getmember(f'{dirname}/{image_index:07d}.jpg') with tar_file.extractfile(member) as f: data = io.BytesIO(f.read()) image = PIL.Image.open(data) image = np.asarray(image) images.append(image) res = np.asarray(images).reshape(nrows, ncols, image_size, image_size, 3).transpose(0, 2, 1, 3, 4).reshape( nrows * image_size, ncols * image_size, 3) seeds = np.asarray(seeds).reshape(nrows, ncols) seed_text = ', '.join(list(map(str, seeds.ravel().tolist()))) return res, seeds, seed_text def main(): args = parse_args() image_size = 128 dirname = '0-99999' tarball_path = download_image_tarball(image_size, dirname) deepdanbooru_predictions = load_deepdanbooru_predictions(dirname) model = create_model() image_paths = load_sample_image_paths() examples = [[path.as_posix(), 2, 5] for path in image_paths] func = functools.partial( run, image_size=image_size, dirname=dirname, tarball_path=tarball_path, deepdanbooru_predictions=deepdanbooru_predictions, model=model, ) func = functools.update_wrapper(func, run) gr.Interface( func, [ gr.inputs.Image(type='pil', label='Input'), gr.inputs.Slider(1, 10, step=1, default=2, label='Number of Rows'), gr.inputs.Slider( 1, 10, step=1, default=5, label='Number of Columns'), ], [ gr.outputs.Image(type='numpy', label='Output'), gr.outputs.Dataframe(type='numpy', label='Seed'), gr.outputs.Textbox(label='Seed (text)'), ], examples=examples, title=TITLE, description=DESCRIPTION, article=ARTICLE, theme=args.theme, allow_screenshot=args.allow_screenshot, allow_flagging=args.allow_flagging, live=args.live, ).launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()