|
import math |
|
from pathlib import Path |
|
|
|
import colorcet as cc |
|
import cv2 |
|
import numpy as np |
|
import timm |
|
import torch |
|
from PIL import Image |
|
from matplotlib.colors import LinearSegmentedColormap |
|
from timm.data import create_transform, resolve_data_config |
|
from timm.models import VisionTransformer |
|
from torch import Tensor, nn |
|
from torch.nn import functional as F |
|
from torchvision import transforms as T |
|
|
|
from .common import Heatmap, ImageLabels, LabelData, pil_make_grid |
|
|
|
|
|
# Anchor working files next to this module when it is imported from a file;
# fall back to the current directory when `__file__` is not defined.
if "__file__" in locals():
    work_dir = Path(__file__).parent.resolve()
else:
    work_dir = Path.cwd().resolve()

# Scratch directory; created on import if it does not exist yet.
temp_dir = work_dir / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
# Process-wide caches keyed by Hugging Face repo id, so each model and its
# preprocessing pipeline are only built once.
model_cache: dict[str, VisionTransformer] = dict()
transform_cache: dict[str, T.Compose] = dict()
|
|
|
|
|
# Default compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    torch_device = torch.device("cuda")
else:
    torch_device = torch.device("cpu")
|
|
|
|
|
class RGBtoBGR(nn.Module):
    """Transform module that reorders the first three channels as [2, 1, 0]."""

    def forward(self, x: Tensor) -> Tensor:
        """Return a copy of ``x`` with channels 0 and 2 swapped.

        A 4-D input is indexed along dimension 1 (treated as batched);
        any other input is indexed along dimension 0.
        """
        bgr_order = [2, 1, 0]
        if x.ndim != 4:
            return x[bgr_order, :, :]
        return x[:, bgr_order, :, :]
|
|
|
|
|
def model_device(model: nn.Module) -> torch.device:
    """Return the device of the first parameter yielded by ``model.parameters()``.

    Raises:
        StopIteration: if the module has no parameters.
    """
    first_param = next(iter(model.parameters()))
    return first_param.device
|
|
|
|
|
def load_model(repo_id: str) -> VisionTransformer:
    """Return the cached model for ``repo_id``, creating it on first use.

    On a cache miss the model is created through timm from the
    ``hf-hub:<repo_id>`` source with pretrained weights, switched to eval
    mode, moved to the module-level ``torch_device`` and stored in
    ``model_cache``.
    """
    cached = model_cache.get(repo_id)
    if cached is not None:
        return cached

    # Cache miss: build the model once and remember it for later calls.
    model = timm.create_model("hf-hub:" + repo_id, pretrained=True)
    model = model.eval().to(torch_device)
    model_cache[repo_id] = model
    return model
|
|
|
|
|
def load_model_and_transform(repo_id: str) -> tuple[VisionTransformer, T.Compose]:
    """Return the cached model and preprocessing transform for ``repo_id``.

    The model is created (eval mode) on a cache miss. The transform is built
    from the model's ``pretrained_cfg`` via timm and extended with a final
    ``RGBtoBGR`` step; both are cached per repo id.

    NOTE(review): unlike ``load_model``, a model created here is not moved to
    ``torch_device``, yet both functions share ``model_cache``. The device of
    a cached model therefore depends on which function created it first —
    confirm this is intended.
    """
    model = model_cache.get(repo_id)
    if model is None:
        model = timm.create_model("hf-hub:" + repo_id, pretrained=True).eval()
        model_cache[repo_id] = model

    transform = transform_cache.get(repo_id)
    if transform is None:
        data_config = resolve_data_config(model.pretrained_cfg, model=model)
        base_transform = create_transform(**data_config)
        # Append the channel swap after timm's own preprocessing steps.
        transform = T.Compose(base_transform.transforms + [RGBtoBGR()])
        transform_cache[repo_id] = transform

    return model, transform
|
|
|
|
|
def get_tags(
    probs: Tensor,
    labels: LabelData,
    gen_threshold: float,
    char_threshold: float,
) -> tuple[str, str, dict, dict, dict]:
    """Split per-label probabilities into rating, general and character tags.

    Args:
        probs: 1-D tensor of probabilities, aligned with ``labels.names``.
            It is converted with ``.numpy()``, so it must be a CPU tensor
            that does not require grad.
        labels: Label metadata providing ``names`` plus the index lists
            ``rating``, ``general`` and ``character``.
        gen_threshold: General tags are kept only if their probability is
            strictly greater than this value.
        char_threshold: Same cut-off, applied to character tags.

    Returns:
        ``(caption, booru, rating_labels, char_labels, gen_labels)`` where
        ``caption`` is the comma-separated general then character tag names
        with parentheses backslash-escaped, ``booru`` is ``caption`` with
        underscores replaced by spaces, ``rating_labels`` maps every rating
        name to its probability, and the two remaining dicts map the kept
        tag names to probabilities in descending order.
    """
    # Pair every label name with its probability once.
    scored = list(zip(labels.names, probs.numpy()))

    def select(indices, threshold) -> dict:
        """Keep entries above ``threshold``, ordered by descending probability."""
        kept = {name: p for name, p in (scored[i] for i in indices) if p > threshold}
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    # Ratings are reported in full, without thresholding.
    rating_labels = {name: p for name, p in (scored[i] for i in labels.rating)}
    gen_labels = select(labels.general, gen_threshold)
    char_labels = select(labels.character, char_threshold)

    # General tags first, then character tags.
    combined_names = [*gen_labels, *char_labels]

    # Raw strings: "\(" and "\)" are invalid escape sequences in a normal
    # literal and trigger a SyntaxWarning on current Python versions.
    caption = ", ".join(combined_names).replace("(", r"\(").replace(")", r"\)")
    booru = caption.replace("_", " ")

    return caption, booru, rating_labels, char_labels, gen_labels
|
|
|
|
|
def _safe_divide(numerator: Tensor, denominator: Tensor) -> Tensor:
    """Divide elementwise, treating non-positive denominators as 1 to avoid NaN/inf."""
    safe = torch.where(denominator > 0, denominator, torch.ones_like(denominator))
    return numerator / safe


@torch.no_grad()
def render_heatmap(
    image: Tensor,
    gradients: Tensor,
    image_feats: Tensor,
    image_probs: Tensor,
    image_labels: list[str],
    cmap: LinearSegmentedColormap = cc.m_linear_bmy_10_95_c71,
    pos_embed_dim: int = 784,
    image_size: tuple[int, int] = (448, 448),
    font_args: dict = {
        "fontFace": cv2.FONT_HERSHEY_SIMPLEX,
        "fontScale": 1,
        "color": (255, 255, 255),
        "thickness": 2,
        "lineType": cv2.LINE_AA,
    },
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image]:
    """Render one heat-map overlay per label and assemble them into a grid.

    Args:
        image: Input image tensor; assumed to be a single (1, C, H, W) image
            normalised to [-1, 1] in BGR channel order, with H x W equal to
            ``image_size`` — TODO confirm against the caller.
        gradients: Per-label gradients w.r.t. ``image_feats``; its first
            dimension is assumed to index labels.
        image_feats: Feature tensor the gradients were taken against.
        image_probs: One probability per entry of ``image_labels``.
        image_labels: Tag names, one per heat map.
        cmap: Colormap used to colour the normalised heat maps.
        pos_embed_dim: Not used by this function; kept for interface
            compatibility.
        image_size: Output size the heat maps are interpolated to.
        font_args: Keyword arguments forwarded to ``cv2.putText``. The
            default dict is shared between calls and is only read here.
        partial_rows: Forwarded to ``pil_make_grid``.

    Returns:
        The list of ``Heatmap`` entries sorted by descending score, and a
        grid image built from them.

    NOTE(review): an empty ``image_labels`` raises ``ZeroDivisionError``
    below, as in the original implementation.
    """
    num_labels = len(image_labels)

    # Weight the features by the token-averaged gradients, then average
    # over the last (channel) dimension to get one value per token.
    weights = gradients.mean(2, keepdim=True)
    per_token = weights.mul(image_feats.unsqueeze(0)).squeeze().mean(-1)
    hmap_dim = int(math.sqrt(per_token.numel() / num_labels))
    image_hmaps = per_token.reshape(num_labels, -1)
    # Keep only the trailing hmap_dim**2 tokens so the map is square.
    image_hmaps = image_hmaps[..., -(hmap_dim**2):]
    image_hmaps = image_hmaps.reshape(num_labels, hmap_dim, hmap_dim)
    # Discard negative contributions.
    image_hmaps = image_hmaps.clamp(min=0)

    # Scale each map by its peak, then min-max normalise it. An all-zero or
    # constant map would otherwise divide by zero and fill the overlay with
    # NaN; such maps now come out as all zeros.
    peak = image_hmaps.amax(dim=(-2, -1), keepdim=True)
    image_hmaps = _safe_divide(image_hmaps, peak)
    low = image_hmaps.amin(dim=(-2, -1), keepdim=True)
    span = image_hmaps.amax(dim=(-2, -1), keepdim=True) - low
    image_hmaps = _safe_divide(image_hmaps - low, span).unsqueeze(1)

    image_hmaps = F.interpolate(image_hmaps, size=image_size, mode="bilinear").squeeze(1)

    # The background image is identical for every tag, so convert it once.
    image_pixels = image.add(1).mul(127.5).squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    hmap_imgs: list[Heatmap] = []
    for tag, hmap, score in zip(image_labels, image_hmaps, image_probs.cpu()):
        hmap_pixels = cmap(hmap.cpu().numpy(), bytes=True)[:, :, :3]

        # Blend in BGR, where OpenCV drawing functions operate.
        hmap_cv2 = cv2.cvtColor(hmap_pixels, cv2.COLOR_RGB2BGR)
        hmap_image = cv2.addWeighted(image_pixels, 0.5, hmap_cv2, 0.5, 0)
        if tag is not None:
            cv2.putText(hmap_image, tag, (10, 30), **font_args)
            cv2.putText(hmap_image, f"{score:.3f}", org=(10, 60), **font_args)

        hmap_pil = Image.fromarray(cv2.cvtColor(hmap_image, cv2.COLOR_BGR2RGB))
        hmap_imgs.append(Heatmap(tag, score.item(), hmap_pil))

    hmap_imgs = sorted(hmap_imgs, key=lambda x: x.score, reverse=True)
    hmap_grid = pil_make_grid([x.image for x in hmap_imgs], partial_rows=partial_rows)

    return hmap_imgs, hmap_grid
|
|
|
|
|
def process_heatmap(
    model: VisionTransformer,
    image: Tensor,
    labels: LabelData,
    threshold: float = 0.5,
    partial_rows: bool = True,
) -> tuple[list[Heatmap], Image.Image, ImageLabels]:
    """Tag an image and render gradient-based heat maps for the predicted tags.

    Args:
        model: Model exposing ``forward_features`` and ``forward_head``.
        image: Preprocessed image tensor; moved to the model's device here.
            Assumed to be a single-image batch — TODO confirm.
        labels: Label metadata (names plus rating/general/character indices).
        threshold: Probability cut-off used both for selecting heat-map
            labels and for the general/character tag lists.
        partial_rows: Forwarded to ``render_heatmap``.

    Returns:
        ``(heatmaps, grid, image_labels)``: the per-tag ``Heatmap`` list, the
        combined grid image, and an ``ImageLabels`` summary of the tags.

    NOTE(review): when no probability exceeds ``threshold`` there are no
    labels to render and the downstream code raises; callers may want to
    handle that case explicitly.
    """
    device = model_device(model)

    # Gradients are needed here even if the caller runs under no_grad.
    with torch.set_grad_enabled(True):
        features = model.forward_features(image.to(device))
        probs = model.forward_head(features)
        probs = F.sigmoid(probs).squeeze(0)

        probs_mask = probs > threshold
        heatmap_probs = probs[probs_mask]

        label_indices = torch.nonzero(probs_mask, as_tuple=False).squeeze(1)
        image_labels = [labels.names[i] for i in label_indices.tolist()]

        # One-hot rows select each probability in turn, so a single batched
        # call yields the gradient of every selected label.
        eye = torch.eye(heatmap_probs.shape[0], device=device)
        grads = torch.autograd.grad(
            outputs=heatmap_probs,
            inputs=features,
            grad_outputs=eye,
            is_grads_batched=True,
            retain_graph=True,
        )
        # Keep batch element 0 and restore a singleton dimension in its place.
        grads = grads[0].detach()[:, 0, :, :].unsqueeze(1)

    with torch.set_grad_enabled(False):
        hmap_imgs, hmap_grid = render_heatmap(
            image=image,
            gradients=grads,
            image_feats=features,
            image_probs=heatmap_probs,
            image_labels=image_labels,
            partial_rows=partial_rows,
        )

        # detach() is required: on a CPU model `.cpu()` returns the same
        # grad-requiring tensor, and `get_tags` calls `.numpy()` on it, which
        # raises for tensors that require grad.
        caption, booru, ratings, character, general = get_tags(
            probs=probs.detach().cpu(),
            labels=labels,
            gen_threshold=threshold,
            char_threshold=threshold,
        )
        image_tags = ImageLabels(caption, booru, ratings, general, character)

    return hmap_imgs, hmap_grid, image_tags
|
|