# ma-images / app.py
# Author: broadwell
# Commit 810de2d (verified): "Allow gradient viz for J-CLIP and ResNet, disable for M-CLIP"
# from base64 import b64encode
from io import BytesIO
from math import ceil
import clip
from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
from transformers import AutoTokenizer, AutoModel, BertTokenizer
from CLIP_Explainability.clip_ import load, tokenize
from CLIP_Explainability.rn_cam import (
# interpret_rn,
interpret_rn_overlapped,
rn_perword_relevance,
)
from CLIP_Explainability.vit_cam import (
# interpret_vit,
vit_perword_relevance,
interpret_vit_overlapped,
)
from pytorch_grad_cam.grad_cam import GradCAM
# Lite mode: skip loading the M-CLIP image encoder to save memory. Text search
# still works for all three models (it uses precomputed image features), but the
# CAM/gradient "Explain this" views are only available for J-CLIP and the legacy
# ResNet model; the UI disables them (and image upload) for M-CLIP.
RUN_LITE = True
# Bounding box (pixels) for the activation overlay shown in the "Explain this"
# dialog; MAX_IMG_HEIGHT also caps the height of result thumbnails in the grid.
MAX_IMG_WIDTH = 500
MAX_IMG_HEIGHT = 800
st.set_page_config(layout="wide")
def find_best_matches(text_features, image_features, image_ids):
    """Rank every image by its similarity to a query embedding.

    Scores are the dot products ``image_features @ text_features.T``. These equal
    cosine similarities when both inputs are L2-normalized, which is what the
    callers in this file pass in.

    Args:
        text_features: query embedding tensor with a single row, e.g. shape (1, D)
            (the ``squeeze(1)`` below relies on that single row).
        image_features: image embeddings, one row per image, shape (N, D).
        image_ids: sequence of N image IDs aligned with the rows of
            ``image_features``.

    Returns:
        A list of ``[image_id, score]`` pairs for all N images, sorted by
        descending score. ``score`` is a Python float.
    """
    similarities = (image_features @ text_features.T).squeeze(1)
    # Convert to Python values once rather than calling .item() for every image.
    scores = similarities.tolist()
    ranking = (-similarities).argsort().tolist()
    return [[image_ids[i], scores[i]] for i in ranking]
def encode_search_query(search_query, model_type):
    """Embed a text query with the text encoder of ``model_type``.

    Args:
        search_query: the query string.
        model_type: one of the CLIP model labels shown in the model selector.
            Any label other than the M-CLIP or J-CLIP ones falls through to the
            legacy M-BERT encoder.

    Returns:
        The L2-normalized query embedding, moved to ``st.session_state.device``.
        It is computed without gradient tracking.
    """
    state = st.session_state
    with torch.no_grad():
        if model_type == "M-CLIP (multilingual ViT)":
            embedding = state.ml_model.forward(search_query, state.ml_tokenizer)
        elif model_type == "J-CLIP (日本語 ViT)":
            tokens = state.ja_tokenizer(
                search_query,
                padding=True,
                return_tensors="pt",
                device=state.device,
            )
            embedding = state.ja_model.get_text_features(**tokens)
        else:  # "Legacy (multilingual ResNet)"
            embedding = state.rn_model(search_query)
        # Unit-normalize so dot products against the (also normalized) image
        # features are cosine similarities.
        embedding /= embedding.norm(dim=-1, keepdim=True)
    return embedding.to(state.device)
def clip_search(search_query):
    """Rank all images against ``search_query`` using the active model.

    Side effects on ``st.session_state``:
      * ``search_field_value`` is synced to ``search_query``, so that a click on
        a suggested-search button shows up in the search box;
      * ``search_image_ids`` receives the ranked image IDs;
      * ``search_image_scores`` receives an image ID -> similarity mapping.

    An empty query leaves the previous results untouched.
    """
    if st.session_state.search_field_value != search_query:
        st.session_state.search_field_value = search_query
    model_type = st.session_state.active_model
    if not search_query:
        return

    text_features = encode_search_query(search_query, model_type)
    # Each model has its own precomputed image-feature matrix; anything that is
    # not M-CLIP or J-CLIP uses the legacy ResNet features.
    feature_key = {
        "M-CLIP (multilingual ViT)": "ml_image_features",
        "J-CLIP (日本語 ViT)": "ja_image_features",
    }.get(model_type, "rn_image_features")
    matches = find_best_matches(
        text_features,
        st.session_state[feature_key],
        st.session_state.image_ids,
    )
    st.session_state.search_image_ids = [image_id for image_id, _ in matches]
    st.session_state.search_image_scores = dict(matches)
def string_search():
    """Re-run the current text search, e.g. after the query, model or features change.

    This also recomputes ``disable_uploader``. In lite mode the M-CLIP image
    encoder is not loaded, so uploads (which open the explanation dialog) are
    disabled while M-CLIP is the active model.
    """
    st.session_state.disable_uploader = RUN_LITE and (
        st.session_state.active_model == "M-CLIP (multilingual ViT)"
    )
    if "search_field_value" not in st.session_state:
        # No query exists yet (e.g. the search box has not been rendered).
        return
    clip_search(st.session_state.search_field_value)
def load_image_features():
    """Load the precomputed image embeddings for the current vision mode.

    Reads one ``.npy`` file per model (``ml``, ``ja`` and ``rn``) from
    ``./image_features/`` and converts each to a tensor on
    ``st.session_state.device``. On CPU the tensor is cast to float32; on any
    other device the stored dtype is kept (presumably float16, TODO confirm).
    Each row is L2-normalized and stored as
    ``st.session_state.<model>_image_features``. Finally the current search is
    re-run so that results reflect the new features.
    """
    # The "stretched" mode's files use the "resized" prefix on disk.
    mode = st.session_state.vision_mode
    if mode == "tiled":
        prefix = "tiled"
    elif mode == "stretched":
        prefix = "resized"
    else:  # "cropped"
        prefix = "cropped"

    arrays = {
        name: np.load(f"./image_features/{prefix}_{name}_features.npy")
        for name in ("ml", "ja", "rn")
    }

    device = st.session_state.device
    tensors = {}
    for name, array in arrays.items():
        tensor = torch.from_numpy(array)
        tensors[name] = (tensor.float() if device == "cpu" else tensor).to(device)

    for name, tensor in tensors.items():
        st.session_state[f"{name}_image_features"] = tensor / tensor.norm(
            dim=-1, keepdim=True
        )
    string_search()
def init():
    """Populate ``st.session_state`` once per browser session.

    Loads the text and image models used for search and for "Explain this", the
    image metadata and ID list, and the default UI state. It then loads the
    precomputed image features through ``load_image_features``, which also
    re-runs any pending search.
    """
    st.session_state.current_page = 1
    # NOTE: inference is pinned to CPU. Use
    # `"cuda" if torch.cuda.is_available() else "cpu"` to enable a GPU.
    device = "cpu"
    st.session_state.device = device

    with st.spinner("Loading models and data, please wait..."):
        # M-CLIP: XLM-RoBERTa text encoder plus ViT-B-16+ image encoder. The image
        # encoder is only needed for CAM visualizations, which are disabled for
        # M-CLIP in lite mode, so it is skipped there.
        ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
        ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
        if not RUN_LITE:
            st.session_state.ml_image_model, st.session_state.ml_image_preprocess = (
                load(ml_model_path, device=device, jit=False)
            )
        st.session_state.ml_model = (
            pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
        ).to(device)
        st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

        # J-CLIP: Japanese text model (remote code) plus ViT-H-14 image encoder
        # for CAM.
        ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
        ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
            ja_model_path, device=device, jit=False
        )
        st.session_state.ja_model = AutoModel.from_pretrained(
            ja_model_name, trust_remote_code=True
        ).to(device)
        st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
            ja_model_name, trust_remote_code=True
        )

        # Legacy: OpenAI RN50x4 image encoder plus multilingual M-BERT text encoder.
        st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
            clip.load("RN50x4", device=device)
        )
        st.session_state.rn_model = legacy_multilingual_clip.load_model(
            "M-BERT-Base-69"
        ).to(device)
        st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
            "bert-base-multilingual-cased"
        )

        # Image metadata is indexed by filename. The ID list is assumed to follow
        # the row order of the precomputed feature arrays, because
        # find_best_matches maps feature row i to image_ids[i].
        st.session_state.images_info = pd.read_csv("./metadata.csv")
        st.session_state.images_info.set_index("filename", inplace=True)
        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
            st.session_state.image_ids = images_list.read().strip().split("\n")

    # Default UI state.
    st.session_state.active_model = "J-CLIP (日本語 ViT)"
    st.session_state.vision_mode = "tiled"
    st.session_state.search_image_ids = []
    st.session_state.search_image_scores = {}
    st.session_state.text_table_df = None
    # Same rule as string_search(): in lite mode only M-CLIP lacks an image
    # encoder, so only M-CLIP disables uploads.
    st.session_state.disable_uploader = (
        RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
    )

    with st.spinner("Loading models and data, please wait..."):
        load_image_features()


# Streamlit re-executes this script on every interaction; initialize once per session.
if "images_info" not in st.session_state:
    init()
def _tile_image(image, img_dim):
    """Turn ``image`` into square ``img_dim`` x ``img_dim`` tiles for the active vision mode.

    This presumably mirrors how the precomputed features were built for each
    ``st.session_state.vision_mode``:

    * ``"tiled"``: resize so the short side is ``img_dim`` and the long side is
      ``round(aspect) * img_dim``, then cut that into ``round(aspect)`` square
      tiles. When the rounded aspect ratio is 1 there is a single square tile.
    * ``"stretched"``: a single tile, resized to a square regardless of aspect.
    * ``"cropped"``: resize with the aspect ratio preserved so the short side is
      ``img_dim``, then take the centre square.

    Returns:
        ``(tiles, tile_behavior, crop_canvas, crop_origin)``:
          * ``tile_behavior`` is ``"width"`` or ``"height"`` when a tiled image
            was split along that axis, else None;
          * ``crop_canvas`` and ``crop_origin`` are the aspect-preserving resize
            and the crop's top-left corner in ``"cropped"`` mode, else None.
    """
    width, height = image.size
    mode = st.session_state.vision_mode

    if mode == "tiled":
        tile_behavior = None
        scale_ratio = 1
        if width > height:
            scale_ratio = round(width / height)
            if scale_ratio > 1:
                tile_behavior = "width"
        elif width < height:
            scale_ratio = round(height / width)
            if scale_ratio > 1:
                tile_behavior = "height"

        if tile_behavior == "width":
            resized = image.resize((scale_ratio * img_dim, img_dim), Image.LANCZOS)
            tiles = [
                resized.crop((x * img_dim, 0, (x + 1) * img_dim, img_dim))
                for x in range(scale_ratio)
            ]
        elif tile_behavior == "height":
            resized = image.resize((img_dim, scale_ratio * img_dim), Image.LANCZOS)
            tiles = [
                resized.crop((0, y * img_dim, img_dim, (y + 1) * img_dim))
                for y in range(scale_ratio)
            ]
        else:
            tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
        return tiles, tile_behavior, None, None

    if mode == "stretched":
        return [image.resize((img_dim, img_dim), Image.LANCZOS)], None, None, None

    # "cropped": scale the short side to img_dim, then centre-crop a square.
    if width > height:
        canvas_dims = (round(width / height * img_dim), img_dim)
    elif width < height:
        canvas_dims = (img_dim, round(height / width * img_dim))
    else:
        canvas_dims = (img_dim, img_dim)
    canvas = image.resize(canvas_dims)
    left = round((canvas_dims[0] - img_dim) / 2)
    top = round((canvas_dims[1] - img_dim) / 2)
    tile = canvas.crop((left, top, left + img_dim, top + img_dim))
    return [tile], None, canvas, (left, top)


def _visualization_text_features():
    """Encode the current query with the active model's text encoder, for CAM use.

    Unlike the search path, this runs without ``torch.no_grad`` and without
    normalization. The result is cast to float32 on CPU and moved to
    ``st.session_state.device``.
    """
    query = st.session_state.search_field_value
    model_type = st.session_state.active_model
    if model_type == "M-CLIP (multilingual ViT)":
        text_features = st.session_state.ml_model.forward(
            query, st.session_state.ml_tokenizer
        )
    elif model_type == "J-CLIP (日本語 ViT)":
        t_text = st.session_state.ja_tokenizer(
            query,
            return_tensors="pt",
            device=st.session_state.device,
        )
        text_features = st.session_state.ja_model.get_text_features(**t_text)
    else:  # "Legacy (multilingual ResNet)"
        text_features = st.session_state.rn_model(query)
    if st.session_state.device == "cpu":
        text_features = text_features.float()
    return text_features.to(st.session_state.device)


def get_overlay_vis(image, img_dim, image_model):
    """Compute a CAM overlay of the current query on ``image`` with the active model.

    Args:
        image: RGB PIL image to explain.
        img_dim: square input resolution expected by ``image_model``.
        image_model: CLIP-style model used for CAM. Its ``.visual`` and
            ``.dtype`` are used, and for ResNet also ``.encode_image``.

    Returns:
        ``(activations_image, image_features, similarity)``:
          * ``activations_image`` is a PIL image of the overlay, with the
            original aspect ratio, fitted inside MAX_IMG_WIDTH x MAX_IMG_HEIGHT;
          * ``image_features`` is a per-tile list consumed by the per-word
            relevance routines (for the ResNet model these are the preprocessed
            tile tensors);
          * ``similarity`` is the mean query/tile similarity (numpy float).
    """
    orig_width, orig_height = image.size
    image_tiles, tile_behavior, crop_canvas, crop_origin = _tile_image(image, img_dim)
    text_features = _visualization_text_features()

    device = st.session_state.device
    model_type = st.session_state.active_model
    dtype = image_model.dtype
    image_visualizations = []
    image_features = []
    image_similarities = []

    if model_type in ("M-CLIP (multilingual ViT)", "J-CLIP (日本語 ViT)"):
        # Both ViT models share the CAM routine; only the preprocessing differs.
        if model_type == "M-CLIP (multilingual ViT)":
            preprocess = st.session_state.ml_image_preprocess
        else:
            preprocess = st.session_state.ja_image_preprocess
        for tile in image_tiles:
            p_image = preprocess(tile).unsqueeze(0).to(device)
            vis_t, img_feats, similarity = interpret_vit_overlapped(
                p_image.type(dtype),
                text_features.type(dtype),
                image_model.visual,
                device,
                img_dim=img_dim,
            )
            image_visualizations.append(vis_t)
            image_features.append(img_feats)
            image_similarities.append(similarity.item())
    else:  # "Legacy (multilingual ResNet)"
        for tile in image_tiles:
            p_image = st.session_state.rn_image_preprocess(tile).unsqueeze(0).to(device)
            vis_t = interpret_rn_overlapped(
                p_image.type(dtype),
                text_features.type(dtype),
                image_model.visual,
                GradCAM,
                device,
                img_dim=img_dim,
            )
            # The ResNet CAM routine does not return a similarity, so compute
            # the cosine similarity here.
            text_features_new = text_features / text_features.norm(dim=-1, keepdim=True)
            image_feats = image_model.encode_image(p_image.type(dtype))
            image_feats_new = image_feats / image_feats.norm(dim=-1, keepdim=True)
            similarity = image_feats_new[0].dot(text_features_new[0])
            image_visualizations.append(vis_t)
            # rn_perword_relevance is later given the preprocessed tile itself.
            image_features.append(p_image)
            image_similarities.append(similarity.item())

    to_pil = ToPILImage()
    vis_images = [to_pil(vis_t) for vis_t in image_visualizations]
    if crop_canvas is not None:
        # Paste the square CAM back over the centre crop so the whole frame is shown.
        crop_canvas.paste(vis_images[0], crop_origin)
        vis_images = [crop_canvas]

    if tile_behavior == "width":
        vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
        for x, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (x * img_dim, 0))
    elif tile_behavior == "height":
        vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
        for y, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (0, y * img_dim))
    else:
        vis_image = vis_images[0]

    # Fit inside the MAX_IMG_WIDTH x MAX_IMG_HEIGHT box, keeping the original
    # aspect ratio. Previously square and near-square portrait images were scaled
    # to MAX_IMG_HEIGHT (800 px) wide. The integer cross-multiplication avoids
    # float error in the comparison.
    if orig_width * MAX_IMG_HEIGHT > orig_height * MAX_IMG_WIDTH:
        scale_factor = MAX_IMG_WIDTH / orig_width
        display_dims = (MAX_IMG_WIDTH, max(1, int(orig_height * scale_factor)))
    else:
        scale_factor = MAX_IMG_HEIGHT / orig_height
        display_dims = (max(1, int(orig_width * scale_factor)), MAX_IMG_HEIGHT)
    activations_image = vis_image.resize(display_dims)

    return activations_image, image_features, np.mean(image_similarities)
# Tokens skipped when scoring per-token importance: common English sub-word
# fragments and articles, plus common Japanese particles and auxiliaries.
_IMPORTANCE_SKIP_TOKENS = frozenset(
    ["s", "ed", "a", "the", "an", "ing", "て", "に", "の", "は", "と", "た"]
)


def _token_importance_table(tokens, image_features, image_model, img_dim, is_legacy):
    """Score each token's contribution to the query/image relevance.

    Runs the per-word relevance routine for every (token, tile) pair and shows a
    progress bar while doing so.

    Returns:
        A DataFrame with columns ``token`` and ``importance``. Importance values
        are softmax-normalized percentages formatted as strings. Tokens whose
        relevance was zero or NaN for every tile are omitted.
    """
    query = st.session_state.search_field_value
    device = st.session_state.device
    scores_per_token = {}
    progress_bar = st.progress(0.0, text=f"Processing {len(tokens)} text tokens")
    for t, token in enumerate(tokens):
        for img_feats in image_features:
            # NOTE(review): both routines use the CLIP BPE `tokenize` and
            # `image_model` (the CAM model), not the text encoder used for
            # search. Confirm this approximation is intended for M-CLIP/J-CLIP.
            if is_legacy:
                word_rel = rn_perword_relevance(
                    img_feats,
                    query,
                    image_model,
                    tokenize,
                    GradCAM,
                    device,
                    token,
                    data_only=True,
                    img_dim=img_dim,
                )
            else:
                word_rel = vit_perword_relevance(
                    img_feats,
                    query,
                    image_model,
                    tokenize,
                    device,
                    token,
                    img_dim=img_dim,
                )
            avg_score = np.mean(word_rel)
            if avg_score == 0 or np.isnan(avg_score):
                continue  # no usable signal for this tile
            # NOTE(review): the raw score is the reciprocal of the mean
            # relevance. Confirm that this inversion is intended.
            scores_per_token.setdefault(token, []).append(1 / avg_score)
        progress_bar.progress(
            (t + 1) / len(tokens),
            text=f"Processing token {t+1} of {len(tokens)}",
        )
    progress_bar.empty()

    token_names = list(scores_per_token)
    avg_scores = [np.mean(scores_per_token[tok]) for tok in token_names]
    normed_scores = torch.softmax(torch.tensor(avg_scores), dim=0)
    importance = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
    return pd.DataFrame({"token": token_names, "importance": importance})


def visualize_gradcam(image):
    """Render the "Explain this" dialog body for ``image`` and the current query.

    The dialog shows:
      * the image itself;
      * its CAM overlay for the active model and the approximate relevance score;
      * on request, a per-token importance table, which is also stored in
        ``st.session_state.text_table_df``.

    Does nothing when no query exists yet.
    """
    if "search_field_value" not in st.session_state:
        return

    header_cols = st.columns([80, 20], vertical_alignment="bottom")
    with header_cols[0]:
        st.title("Image + query activation gradients")
    with header_cols[1]:
        if st.button("Close"):
            st.rerun()

    query = st.session_state.search_field_value
    active_model = st.session_state.active_model
    is_legacy = active_model == "Legacy (multilingual ResNet)"
    # Per-model CAM input size, CAM image model, and the tokenizer used only to
    # list candidate tokens for the importance table.
    if active_model == "M-CLIP (multilingual ViT)":
        img_dim = 240
        image_model = st.session_state.ml_image_model
        tokenizer = st.session_state.ml_tokenizer
    elif is_legacy:
        img_dim = 288
        image_model = st.session_state.rn_image_model
        tokenizer = st.session_state.rn_tokenizer
    else:  # J-CLIP
        img_dim = 224
        image_model = st.session_state.ja_image_model
        tokenizer = st.session_state.ja_tokenizer
    tokenized_text = tokenizer.tokenize(query)

    st.image(image)
    with st.spinner("Calculating..."):
        activations_image, image_features, similarity_score = get_overlay_vis(
            image, img_dim, image_model
        )
    st.markdown(
        f"**Query text:** {query} | **Approx. image relevance:** {round(similarity_score.item(), 3)}"
    )
    st.image(activations_image)

    # Strip word-boundary markers ("▁") and sub-word continuation marks ("#").
    # Then drop tokens that become empty (e.g. a bare "##") and skip-listed tokens.
    tokens = [
        tok.replace("▁", "").replace("#", "") for tok in tokenized_text if tok != "▁"
    ]
    tokens = [tok for tok in tokens if tok and tok not in _IMPORTANCE_SKIP_TOKENS]

    # The button is only rendered for 2..24 tokens (short-circuit evaluation).
    if 1 < len(tokens) < 25 and st.button(
        "Calculate text importance (may take some time)",
    ):
        token_table = _token_importance_table(
            tokens, image_features, image_model, img_dim, is_legacy
        )
        st.session_state.text_table_df = token_table
        st.markdown("**Importance of each text token to relevance score**")
        st.table(token_table)
@st.dialog(" ", width="large")
def image_modal(image):
    """Open the "Explain this" dialog for an RGB PIL ``image``."""
    visualize_gradcam(image)


def vis_known_image(vis_image_id, timeout=30):
    """Download a catalogue image by ID and open the explanation dialog for it.

    Args:
        vis_image_id: index value (filename) in ``st.session_state.images_info``.
        timeout: seconds to wait for the image server before giving up. Without
            it, a stalled server would hang the callback indefinitely.

    On a network or HTTP error an error message is shown and no dialog opens.
    Previously such an error page reached PIL and failed with an opaque
    decoding error.
    """
    image_url = st.session_state.images_info.loc[vis_image_id]["image_url"]
    try:
        image_response = requests.get(image_url, timeout=timeout)
        image_response.raise_for_status()
    except requests.RequestException as err:
        st.error(f"Could not download image from {image_url}: {err}")
        return
    image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF", "PNG"])
    image_modal(image.convert("RGB"))


def vis_uploaded_image():
    """Open the explanation dialog for the file in the uploader, if there is one."""
    uploaded_file = st.session_state.uploaded_image
    if uploaded_file is None:
        return
    image = Image.open(BytesIO(uploaded_file.getvalue()), formats=["JPEG", "GIF", "PNG"])
    image_modal(image.convert("RGB"))
def format_vision_mode(mode_stub):
    """Return the dropdown label for a vision-mode key, e.g. ``"tiled"`` -> ``"Tiled"``."""
    label = mode_stub.capitalize()
    return label
st.title("Explore Japanese visual aesthetics with CLIP models")

# Page-wide CSS overrides:
#   * tighten caption padding and vertical block spacing;
#   * shrink secondary buttons;
#   * hide the dialog's built-in close button (the dialog renders its own
#     "Close" button);
#   * centre content inside full-screen frames.
_PAGE_CSS = """
<style>
[data-testid=stImageCaption] {
padding: 0 0 0 0;
}
[data-testid=stVerticalBlockBorderWrapper] {
line-height: 1.2;
}
[data-testid=stVerticalBlock] {
gap: .75rem;
}
[data-testid=baseButton-secondary] {
min-height: 1rem;
padding: 0 0.75rem;
margin: 0 0 1rem 0;
}
div[aria-label="dialog"]>button[aria-label="Close"] {
display: none;
}
[data-testid=stFullScreenFrame] {
display: flex;
flex-direction: column;
align-items: center;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Search bar: query box, Search button, vision-mode selector, model selector.
search_row = st.columns([45, 8, 8, 10, 1, 8, 20], vertical_alignment="center")
(
    query_col,
    search_button_col,
    mode_label_col,
    mode_select_col,
    spacer_col,
    model_label_col,
    model_select_col,
) = search_row

search_field = query_col.text_input(
    label="search",
    label_visibility="collapsed",
    placeholder="Type something, or click a suggested search below.",
    on_change=string_search,
    key="search_field_value",
)
search_button_col.button(
    "Search", on_click=string_search, use_container_width=True, type="primary"
)
mode_label_col.markdown("**Vision mode:**")
# Switching modes reloads the matching precomputed features and re-runs the search.
mode_select_col.selectbox(
    "Vision mode",
    options=["tiled", "stretched", "cropped"],
    key="vision_mode",
    help="How to consider images that aren't square",
    on_change=load_image_features,
    format_func=format_vision_mode,
    label_visibility="collapsed",
)
spacer_col.empty()
model_label_col.markdown("**CLIP model:**")
model_select_col.selectbox(
    "CLIP Model:",
    options=[
        "J-CLIP (日本語 ViT)",
        "M-CLIP (multilingual ViT)",
        "Legacy (multilingual ResNet)",
    ],
    key="active_model",
    on_change=string_search,
    label_visibility="collapsed",
)
# One-click example queries. Japanese examples are shown for J-CLIP and
# multilingual ones otherwise.
canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
canned_searches[0].markdown("**Suggested searches:**")
if st.session_state.active_model == "J-CLIP (日本語 ViT)":
    suggestions = ["間", "奥", "山", "花に酔えり 羽織着て刀 さす女"]
else:
    suggestions = ["negative space", "間", "음각", "αρνητικός χώρος"]
for suggestion_col, suggestion in zip(canned_searches[1:], suggestions):
    suggestion_col.button(
        suggestion,
        on_click=clip_search,
        args=[suggestion],
        use_container_width=True,
    )
# Paging controls and the image uploader.
controls = st.columns([25, 25, 20, 35], gap="large", vertical_alignment="center")
with controls[0]:
    im_per_pg = st.columns([30, 70], vertical_alignment="center")
    with im_per_pg[0]:
        st.markdown("**Images/page:**")
    with im_per_pg[1]:
        batch_size = st.select_slider(
            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
        )
with controls[1]:
    im_per_row = st.columns([30, 70], vertical_alignment="center")
    with im_per_row[0]:
        st.markdown("**Images/row:**")
    with im_per_row[1]:
        row_size = st.select_slider(
            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
        )

# Keep at least one page, so that the page selector's max_value never drops
# below its min_value.
num_batches = max(1, ceil(len(st.session_state.image_ids) / batch_size))
# Raising "Images/page" can leave the stored page beyond the new last page,
# which gives a blank grid and an out-of-range page selector. Clamp it here:
# a widget's state may only be written before the widget is created in this run.
if st.session_state.current_page > num_batches:
    st.session_state.current_page = num_batches

with controls[2]:
    pager = st.columns([40, 60], vertical_alignment="center")
    with pager[0]:
        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
    with pager[1]:
        st.number_input(
            "Page",
            min_value=1,
            max_value=num_batches,
            step=1,
            label_visibility="collapsed",
            key="current_page",
        )
with controls[3]:
    st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "gif", "png"],
        key="uploaded_image",
        label_visibility="collapsed",
        on_change=vis_uploaded_image,
        disabled=st.session_state.disable_uploader,
    )
# Current page of ranked results. Slicing an empty result list yields an empty page.
page_start = (st.session_state.current_page - 1) * batch_size
batch = st.session_state.search_image_ids[page_start : page_start + batch_size]

# Explanations need the model's image encoder, which lite mode skips for M-CLIP.
explain_enabled = not (
    RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
)

# Results are dealt round-robin into row_size columns.
grid = st.columns(row_size)
for position, image_id in enumerate(batch):
    info = st.session_state.images_info.loc[image_id]
    # permalink is "scheme://host/...", so element 2 of the split is the host.
    link_text = info["permalink"].split("/")[2]
    with grid[position % row_size]:
        st.html(
            f"""<div style="display: flex; flex-direction: column; align-items: center">
<img src="{info['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
<div>{info['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
</div>"""
        )
        # Previously this ended with "<div>" (an unclosed extra div).
        st.caption(
            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
<a href="{info['permalink']}">{link_text}</a>
</div>""",
            unsafe_allow_html=True,
        )
        if explain_enabled:
            st.button(
                "Explain this",
                on_click=vis_known_image,
                args=[image_id],
                use_container_width=True,
                key=image_id,
            )
        else:
            st.empty()