|
import json |
|
from functools import lru_cache |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from autofaiss import build_index |
|
from hfutils.operate import get_hf_fs |
|
from huggingface_hub import hf_hub_download |
|
from imgutils.data import load_image |
|
from imgutils.metrics import ccip_batch_extract_features, ccip_batch_differences, ccip_default_threshold |
|
|
|
# Hugging Face dataset repository that stores the character index files
# (tag metadata, embeddings, per-character preview images and features).
SRC_REPO: str = 'deepghs/character_index'

# Module-wide Hugging Face filesystem handle, used to read text files
# straight out of the dataset repository.
hf_fs = get_hf_fs()
|
|
|
|
|
@lru_cache()
def _make_index():
    """Load the tag metadata and build the in-memory search index.

    The result is memoized by ``lru_cache``, so the download and the index
    construction only happen on the first call.

    :return: ``((index, index_infos), tag_infos)`` where ``index`` and
        ``index_infos`` are the two values returned by autofaiss'
        ``build_index`` and ``tag_infos`` is a numpy array built from the
        parsed ``index/tag_infos.json`` file of the dataset repository.
    """
    # Tag metadata is read as text through the shared filesystem handle.
    tag_infos_path = f'datasets/{SRC_REPO}/index/tag_infos.json'
    raw_tag_infos = json.loads(hf_fs.read_text(tag_infos_path))
    tag_infos = np.array(raw_tag_infos)

    # The embedding matrix is fetched to the local hub cache before loading.
    embeddings_file = hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename='index/embeddings.npy',
    )
    embeddings = np.load(embeddings_file)

    # Keep the index purely in memory; nothing is written to disk.
    index, index_infos = build_index(embeddings, save_on_disk=False)
    return (index, index_infos), tag_infos
|
|
|
|
|
def gender_predict(p):
    """Turn a pair of ``boy`` / ``girl`` scores into a gender label.

    :param p: Mapping providing numeric values under the keys ``'boy'`` and
        ``'girl'``.
    :return: ``'male'`` when the ``boy`` score leads by at least ``0.1``,
        ``'female'`` when the ``girl`` score leads by at least ``0.1``,
        otherwise ``'not_sure'``.
    """
    boy_score, girl_score = p['boy'], p['girl']

    if boy_score - girl_score >= 0.1:
        return 'male'
    if girl_score - boy_score >= 0.1:
        return 'female'
    # Neither side leads by a clear enough margin.
    return 'not_sure'
|
|
|
|
|
def query_character(image: Image.Image, count: int = 5, order_by: str = 'same_ratio', threshold: float = 0.7):
    """Look up the indexed characters that best match ``image``.

    The image's CCIP feature is L2-normalized and searched against the index
    from ``_make_index``. For every hit, the character's preview image
    (``1.webp``) and stored features (``feat.npy``) are downloaded from the
    dataset repository, and the CCIP differences between the query and each
    stored feature are summarized.

    :param image: Query image.
    :param count: Number of nearest neighbours requested from the index.
    :param order_by: Column used for ranking and threshold filtering; one of
        ``'same_ratio'``, ``'mean_diff'`` or ``'index_score'``.
    :param threshold: Rows whose ``order_by`` value is below this are dropped.
    :return: ``(ret_images, df_records)``. ``ret_images`` is a list of
        ``(preview_image, caption)`` tuples with captions of the form
        ``'<tag> (<score>)'``. ``df_records`` is the sorted and filtered
        DataFrame with columns ``id``, ``tag``, ``gender``, ``copyright``,
        ``index_score``, ``mean_diff`` and ``same_ratio``. Both are empty
        when the index yields no usable hit.
    :raises KeyError: If ``order_by`` is not one of the record columns.

    NOTE(review): ranking is always descending and the filter keeps values
    ``>= threshold``. That suits ``same_ratio`` and ``index_score``; with
    ``order_by='mean_diff'`` (where a smaller difference presumably means a
    closer match) it would keep the least similar rows -- confirm intent.
    """
    (index, _), tag_infos = _make_index()

    query = ccip_batch_extract_features([image])
    # Internal sanity check on the feature extractor's output shape.
    assert query.shape == (1, 768)
    query = query / np.linalg.norm(query)

    all_dists, all_indices = index.search(query, k=count)
    dists, indices = all_dists[0], all_indices[0]

    # Loop-invariant, so fetch it once instead of once per hit.
    same_threshold = ccip_default_threshold()

    images, records = {}, []
    for dist, idx in zip(dists, indices):
        # faiss pads the result with label -1 when fewer than ``k``
        # neighbours are found; indexing ``tag_infos[-1]`` would silently
        # return the last entry as a bogus match, so skip those slots.
        if idx < 0:
            continue

        info = tag_infos[idx]
        character_dir = f'{info["hprefix"]}/{info["short_tag"]}'
        current_image = load_image(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{character_dir}/1.webp',
        ))
        feats = np.load(hf_hub_download(
            repo_id=SRC_REPO,
            repo_type='dataset',
            filename=f'{character_dir}/feat.npy',
        ))

        # Row 0 holds the query's difference to every item; column 0 is the
        # query against itself, so drop it.
        diffs = ccip_batch_differences([query[0], *feats])[0, 1:]

        images[info['tag']] = current_image
        records.append({
            'id': info['id'],
            'tag': info['tag'],
            'gender': gender_predict(info['gender']),
            'copyright': info['copyright'],
            'index_score': dist,
            'mean_diff': diffs.mean(),
            'same_ratio': (diffs < same_threshold).mean(),
        })

    # Explicit columns keep sorting and filtering valid when there are no
    # records (a column-less frame would raise KeyError in ``sort_values``).
    record_columns = ['id', 'tag', 'gender', 'copyright', 'index_score', 'mean_diff', 'same_ratio']
    df_records = pd.DataFrame(records, columns=record_columns)

    # ``index_score`` breaks ties unless it is already the primary key.
    sort_keys = [order_by] if order_by == 'index_score' else [order_by, 'index_score']
    df_records = df_records.sort_values(by=sort_keys, ascending=[False] * len(sort_keys))
    df_records = df_records[df_records[order_by] >= threshold]

    ret_images = [
        (images[row_item['tag']], f'{row_item["tag"]} ({row_item[order_by]:.3f})')
        for row_item in df_records.to_dict('records')
    ]
    return ret_images, df_records
|
|