"""Gradio demo: find visually similar anime images across multiple booru sites.

The app embeds a query image with a WD14 tagger, searches a pre-built FAISS
KNN index hosted on Hugging Face, and downloads the matching images via
cheesechaser data pools.
"""
import json
import os
from collections import defaultdict
from functools import lru_cache
from typing import List, Dict

import faiss
import gradio as gr
import numpy as np
from PIL import Image

from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \
    KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14

# Hugging Face repository hosting the pre-built FAISS indices and ID arrays.
_REPO_ID = 'deepghs/anime_sites_indices'

hf_fs = get_hf_fs()
hf_client = get_hf_client()

_DEFAULT_MODEL_NAME = 'SwinV2_v3_dgzyka_23325111_8GB'
# Every directory in the repo that contains a 'knn.index' file is a selectable index.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _REPO_ID))
    for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]

# Maps a site name (as encoded in the index's raw IDs) to its cheesechaser data pool.
_SITE_CLS = {
    'danbooru': DanbooruNewestWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
    'konachan': KonachanWebpDataPool,
    'anime_pictures': AnimePicturesWebpDataPool,
    'rule34': Rule34WebpDataPool,
}


def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download the images with the given numeric IDs from one site."""
    with TemporaryDirectory() as td:
        datapool = _SITE_CLS[site_name]()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            id_ = int(os.path.splitext(file)[0])
            image = Image.open(os.path.join(td, file))
            # Force-load the pixel data before the temporary directory is removed.
            image.load()
            retval[id_] = image

        return retval


def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Download images for raw IDs of the form '<site>_<numeric id>', grouped by site."""
    _sites = defaultdict(list)
    for id_ in ids:
        site_name, num_id = id_.rsplit('_', maxsplit=1)
        num_id = int(num_id)
        _sites[site_name].append(num_id)

    _retval = {}
    for site_name, site_ids in _sites.items():
        _retval.update({
            f'{site_name}_{id_}': image
            for id_, image in _get_from_ids(site_name, site_ids).items()
        })
    return _retval


@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Load (and cache) the ID array and the tuned FAISS index for one model."""
    image_ids = np.load(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/ids.npy',
    ))
    knn_index = faiss.read_index(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/knn.index',
    ))
    with open(hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/infos.json',
    )) as f:
        config = json.load(f)["index_param"]
    # Apply the search-time parameters the index was built with.
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index


def search(model_name: str, img_input, n_neighbours: int):
    """Embed the query image and return its nearest neighbours as (image, caption) pairs."""
    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)

    # All indices are built from SwinV2_v3 embeddings, so the tagger model is fixed here.
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    embeddings = np.expand_dims(embeddings, 0)
    # The index stores L2-normalized vectors, so the query must be normalized too.
    faiss.normalize_L2(embeddings)

    dists, indexes = knn_index.search(embeddings, k=n_neighbours)
    neighbours_ids = images_ids[indexes][0]

    captions = []
    images = []
    ids_to_images = _get_from_raw_ids(neighbours_ids)
    for image_id, dist in zip(neighbours_ids, dists[0]):
        # Skip IDs whose images could not be downloaded.
        if image_id in ids_to_images:
            images.append(ids_to_images[image_id])
            captions.append(f"{image_id}/{dist:.2f}")

    return list(zip(images, captions))


if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()
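
# A minimal sketch of calling `search` programmatically instead of through the
# UI, e.g. after importing this module. The path 'query.png' is hypothetical;
# any PIL-loadable image works:
#
#     from PIL import Image
#     results = search(_DEFAULT_MODEL_NAME, Image.open('query.png'), n_neighbours=5)
#     for _image, caption in results:
#         print(caption)  # e.g. 'danbooru_1234567/0.12', i.e. '<raw id>/<distance>'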