|
import math |
|
from dataclasses import dataclass |
|
from functools import lru_cache |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub.utils import HfHubHTTPError |
|
from PIL import Image |
|
from torch import Tensor, nn |
|
|
|
|
|
@dataclass
class Heatmap:
    """Bundle a heatmap image together with the label and score it belongs to."""

    # Name of the label the heatmap was produced for.
    label: str
    # Score associated with that label.
    score: float
    # The heatmap itself, as a PIL image.
    image: Image.Image
|
|
|
|
|
@dataclass
class LabelData:
    """Label vocabulary plus the positions of each label category within it.

    The three index lists hold integer positions into ``names``; they are
    filled from the ``category`` column of the labels CSV by ``load_labels_hf``.
    """

    # Every label name, in CSV row order.
    names: list[str]
    # Positions in ``names`` of the rating labels.
    rating: list[np.int64]
    # Positions in ``names`` of the general labels.
    general: list[np.int64]
    # Positions in ``names`` of the character labels.
    character: list[np.int64]
|
|
|
|
|
@dataclass
class ImageLabels:
    """Labelling result for one image: two string renderings plus per-category scores.

    NOTE(review): the exact formatting of ``caption`` and ``booru`` is decided
    by whatever code builds this object and is not visible from this class.
    """

    # Caption-style string form of the labels.
    caption: str
    # Booru-style string form of the labels.
    booru: str
    # Mapping of rating label name -> score.
    rating: dict[str, float]
    # Mapping of general label name -> score.
    general: dict[str, float]
    # Mapping of character label name -> score.
    character: dict[str, float]
|
|
|
|
|
@lru_cache(maxsize=5)
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hugging Face repo and index its labels.

    Results are memoised (up to 5 distinct argument combinations) so repeated
    calls for the same repo/revision/token do not re-read the CSV.

    Args:
        repo_id: Hub repository that hosts ``selected_tags.csv``.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        A ``LabelData`` holding every label name plus the row indices whose
        ``category`` column equals 9 (rating), 0 (general) and 4 (character).

    Raises:
        FileNotFoundError: If the hub download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
        csv_path = Path(downloaded).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    # Only the two columns needed for indexing are parsed.
    frame: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = frame["category"]

    def rows_in(category_id: int) -> list[np.int64]:
        # Row positions whose category column matches the given id.
        return list(np.where(categories == category_id)[0])

    return LabelData(
        names=frame["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
|
|
|
|
|
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are ranked in descending order, the largest gap between two
    neighbouring scores is located, and the midpoint of that gap is returned
    as the threshold.

    Args:
        probs: One-dimensional array (or array-like) of scores. At least two
            values are required, since the method needs a gap to cut at.

    Returns:
        The threshold as a plain Python ``float``.

    Raises:
        ValueError: If fewer than two scores are supplied.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        # Without this check numpy fails later with an opaque
        # "attempt to get an argmax of an empty sequence" error.
        raise ValueError(
            f"mcut_threshold needs at least two probabilities, got {probs.size}"
        )

    ranked = np.sort(probs)[::-1]
    # gaps[i] is the drop from the i-th to the (i+1)-th ranked score.
    gaps = ranked[:-1] - ranked[1:]
    cut = gaps.argmax()
    return float((ranked[cut] + ranked[cut + 1]) / 2)
|
|
|
|
|
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an opaque RGB image, flattening transparency onto white.

    Args:
        image: Source PIL image in any mode.

    Returns:
        An RGB image. Images that are already RGB are returned unchanged;
        images carrying transparency are composited over a white background.
    """
    if image.mode not in ["RGB", "RGBA"]:
        # Transparency can be stored either as metadata (info["transparency"],
        # e.g. palette images) or as a dedicated alpha band (modes such as
        # "LA" / "PA"). Both must go through RGBA so the alpha is composited
        # below rather than silently discarded by a direct RGB conversion.
        has_alpha = "transparency" in image.info or "A" in image.getbands()
        image = image.convert("RGBA" if has_alpha else "RGB")

    if image.mode == "RGBA":
        # Composite over opaque white, then drop the alpha channel.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
|
|
|
|
|
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Centre ``image`` on a new square RGB canvas.

    Args:
        image: Image to pad.
        fill: RGB colour used for the padding (white by default).

    Returns:
        A new RGB image whose side equals the longer edge of ``image``, with
        ``image`` pasted in the middle.
    """
    width, height = image.size
    side = max(width, height)

    # Split the spare space evenly; integer division biases towards top/left.
    offset = ((side - width) // 2, (side - height) // 2)

    square = Image.new("RGB", (side, side), fill)
    square.paste(image, offset)
    return square
|
|
|
|
|
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    The image is converted to RGB, padded to a square, and then scaled to
    ``size_px``: images smaller than the target are resized with LANCZOS,
    images larger than the target are shrunk with ``thumbnail`` (BICUBIC).

    Args:
        image: Source PIL image.
        size_px: Target size; an ``int`` is treated as ``(size_px, size_px)``.
        upscale: Whether images smaller than the target may be enlarged.

    Returns:
        The processed RGB image.

    Raises:
        ValueError: If the image is smaller than the target and ``upscale``
            is ``False``.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px

    # RGB conversion first, then square padding.
    image = pil_pad_square(pil_ensure_rgb(image))

    width, height = image.size
    if width < target[0] or height < target[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        image = image.resize(target, Image.LANCZOS)

    # Re-read the size: it may have changed in the branch above.
    if image.size[0] > target[0] or image.size[1] > target[1]:
        image.thumbnail(target, Image.BICUBIC)

    return image
|
|
|
|
|
def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),
    partial_rows: bool = True,
) -> Image.Image:
    """Tile ``images`` into a single grid image.

    The column count is the integer square root of the number of images,
    capped at ``max_cols``. Cell dimensions are taken from the first image,
    so all images are expected to share its size.

    Args:
        images: Images to lay out, in row-major order. Must not be empty.
        max_cols: Upper bound on the number of columns (at least 1).
        padding: Gap in pixels between cells and around the grid border.
        bg_color: RGB colour of the canvas behind the images.
        partial_rows: When ``False``, an incomplete final row is dropped
            together with the images that would have occupied it.

    Returns:
        A new RGB image containing the grid.

    Raises:
        ValueError: If ``images`` is empty or ``max_cols`` is less than 1.
    """
    if not images:
        raise ValueError("pil_make_grid requires at least one image")
    if max_cols < 1:
        raise ValueError("max_cols must be at least 1")

    n_cols = min(math.isqrt(len(images)), max_cols)
    n_rows = math.ceil(len(images) / n_cols)

    # Drop the trailing, incomplete row when partial rows are not wanted.
    if n_cols * n_rows > len(images) and not partial_rows:
        n_rows -= 1

    image_width, image_height = images[0].size

    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding

    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
    # Only paste what fits on the canvas; images belonging to a dropped
    # partial row would otherwise be pasted entirely out of bounds.
    for i, img in enumerate(images[: n_cols * n_rows]):
        x = (i % n_cols) * (image_width + padding) + padding
        y = (i // n_cols) * (image_height + padding) + padding
        canvas.paste(img, (x, y))

    return canvas
|
|
|
|
|
|
|
# Emoticon-style ("kaomoji") names in which the underscore is part of the face.
# NOTE(review): presumably consulted so these underscores are left intact when
# label names are reformatted - confirm against the code that uses this list.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
|
|