import streamlit as st
import numpy as np
import base64
from io import BytesIO
from multilingual_clip import pt_multilingual_clip
from transformers import CLIPTokenizerFast, AutoTokenizer, CLIPModel
import torch
import logging
from os import environ
from parse import parse
from clickhouse_connect import get_client

environ['TOKENIZERS_PARALLELISM'] = 'true'

db_name_map = {
    "Unsplash Photos 25K": lambda feat: f"mqdb_demo.unsplash_25k_{feat}_indexer",
    "RSICD: Remote Sensing Images 11K": lambda feat: f"mqdb_demo.rsicd_{feat}_b_32",
}
feat_name_map = {
    'Vanilla CLIP': "clip",
    'CLIP finetuned on RSICD': "cliprsicd"
}

DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
DIMS = 512
# Ignore some bad links (broken in the dataset already)
BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8',
           'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}

# NOTE: assumed cache decorator -- refresh_index() below calls init_db.clear(),
# which requires init_db to be a cached (singleton) resource.
@st.experimental_singleton
def init_db():
    """ Initialize the Database Connection

    Returns:
        meta_field: Meta field that records whether an image has been viewed or not
        client: Database connection object
    """
    r = parse("{http_pre}://{host}:{port}", st.secrets["DB_URL"])
    client = get_client(
        host=r['host'], port=r['port'], user=st.secrets["USER"], password=st.secrets["PASSWD"],
        interface=r['http_pre'],
    )
    meta_field = {}
    return meta_field, client
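
# For illustration only: the secrets parsed above are assumed to look roughly like
# the following (hypothetical values, set via .streamlit/secrets.toml):
#   DB_URL = "https://your-myscale-host:8443"
#   USER   = "default"
#   PASSWD = "your-password"
# parse("{http_pre}://{host}:{port}", ...) then yields http_pre="https",
# host="your-myscale-host", port="8443".
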
def init_query_num():
    print("init query_num")
    return 0

def query(xq, top_k=10):
    """ Query the top-k matched records w.r.t. a given vector

    Args:
        xq (numpy.ndarray or list of floats): Query vector
        top_k (int, optional): Number of matched vectors. Defaults to 10.

    Returns:
        matches: list of record dicts. Keys refer to the selected columns
    """
    attempt = 0
    xq = xq / np.linalg.norm(xq)
    while attempt < 3:
        try:
            xq_s = f"[{', '.join([str(float(fnum)) for fnum in list(xq)])}]"
            print('Excluded pre:', st.session_state.meta)
            if len(st.session_state.meta) > 0:
                exclude_list = ','.join(
                    [f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
                print("Excluded:", exclude_list)
                # A column filter (WHERE on `id`) runs before the vector search
                # so already-viewed images are excluded from the result set
                xc = st.session_state.index.query(f"SELECT id, url, vector,\
                        distance(vector, {xq_s}) AS dist\
                        FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
                        WHERE id NOT IN ({exclude_list}) ORDER BY dist LIMIT {top_k}").named_results()
            else:
                xc = st.session_state.index.query(f"SELECT id, url, vector,\
                        distance(vector, {xq_s}) AS dist\
                        FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
                        ORDER BY dist LIMIT {top_k}").named_results()
            # Unfiltered top-k, used for the sidebar "nearest in database" view
            real_xc = st.session_state.index.query(f"SELECT id, url, vector,\
                    distance(vector, {xq_s}) AS dist \
                    FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
                    ORDER BY dist LIMIT {top_k}").named_results()
            top_k = [{k: v for k, v in r.items()} for r in real_xc]
            xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
                  st.session_state.meta[xi['id']] < 1]
            logging.info(
                f'{len(xc)} records returned, {[_i["id"] for _i in xc]}')
            matches = xc
            break
        except Exception as e:
            # force a reload if we run into connection trouble or anything else
            logging.warning(str(e))
            _, st.session_state.index = init_db()
            attempt += 1
            matches = []
    if len(matches) == 0:
        logging.error(f"No matches found for '{DB_NAME}'")
    return matches, top_k
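
# A sketch of the SQL the loop above renders (values shortened for illustration):
#   SELECT id, url, vector, distance(vector, [0.12, -0.03, ...]) AS dist
#   FROM mqdb_demo.unsplash_25k_clip_indexer
#   WHERE id NOT IN ('9_9hzZVjV8s', ...)
#   ORDER BY dist LIMIT 10
# The WHERE clause is only added once some images have been scored and excluded.
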
def init_random_query():
    xq = np.random.rand(DIMS).tolist()
    return xq, xq.copy()

class Classifier:
    """ Zero-shot Classifier

    This classifier acts as a proxy for the user's reaction to the probed images.
    The proxy replaces the original query vector generated from the prompt and finally
    gives the user a satisfying retrieval result.
    This is commonly seen in recommendation systems: the classifier recommends more
    precise results as it accumulates the user's activity.
    """

    def __init__(self, xq: list):
        # initialize model with DIMS input size and 1 output
        # note that the bias is ignored, as we only focus on the inner product result
        self.model = torch.nn.Linear(DIMS, 1, bias=False)
        # convert initial query `xq` to tensor parameter to init weights
        init_weight = torch.Tensor(xq).reshape(1, -1)
        self.model.weight = torch.nn.Parameter(init_weight)
        # init loss and optimizer
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X: list, y: list, iters: int = 5):
        # convert X and y to tensors
        X = torch.Tensor(X)
        y = torch.Tensor(y).reshape(-1, 1)
        for i in range(iters):
            # zero gradients
            self.optimizer.zero_grad()
            # Normalize the weight before inference.
            # This constrains the gradient; otherwise the query vector explodes.
            self.model.weight.data = self.model.weight.data / \
                torch.norm(self.model.weight.data, p=2, dim=-1)
            # forward pass
            out = self.model(X)
            # compute loss
            loss = self.loss(out, y)
            # backward pass
            loss.backward()
            # update weights
            self.optimizer.step()

    def get_weights(self):
        xq = self.model.weight.detach().numpy()[0].tolist()
        return xq
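
# Minimal usage sketch (for illustration only, not executed by the app):
#   clf = Classifier(xq)              # xq: DIMS-dimensional prompt vector
#   clf.fit(X=[v1, v2, ...],          # vectors of the displayed images
#           y=[1.0, 0.0, ...],        # relevance scores from the sliders
#           iters=2)
#   new_xq = clf.get_weights()        # updated query vector for the next search
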
class NormalizingLayer(torch.nn.Module):
    def forward(self, x):
        return x / torch.norm(x, dim=-1, keepdim=True)


def card(i, url):
    return f'<img id="img{i}" src="{url}" width="200px;">'


def card_with_conf(i, conf, url):
    conf = "%.4f" % (conf)
    return f'<img id="img{i}" src="{url}" width="200px;" style="margin:50px 50px"><div><p><b>Relevance: {conf}</b></p></div>'

def get_top_k(xq, top_k=9):
    """ Wrapper function for query

    Args:
        xq (numpy.ndarray or list of floats): Query vector
        top_k (int, optional): Number of returned vectors. Defaults to 9.

    Returns:
        matches: See `query()`
    """
    matches = query(
        xq, top_k=top_k
    )
    return matches

def tune(X, y, iters=2):
    """ Train the Zero-shot Classifier

    Args:
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by the user
        iters (int, optional): Number of update iterations to run
    """
    assert len(X) == len(y)
    # train the classifier
    st.session_state.clf.fit(X, y, iters=iters)
    # extract new vector
    st.session_state.xq = st.session_state.clf.get_weights()

def refresh_index():
    """ Clean the session
    """
    del st.session_state["meta"]
    st.session_state.meta = {}
    st.session_state.query_num = 0
    logging.info(f"Refresh for '{st.session_state.meta}'")
    init_db.clear()
    # refresh session states
    st.session_state.meta, st.session_state.index = init_db()
    del st.session_state.clf, st.session_state.xq

def calc_dist():
    xq = np.array(st.session_state.xq)
    orig_xq = np.array(st.session_state.orig_xq)
    return np.linalg.norm(xq - orig_xq)

def submit():
    """ Tune the model w.r.t. the scores given by the user.
    """
    st.session_state.query_num += 1
    matches = st.session_state.matches
    velocity = 1  # st.session_state.velocity
    scores = {}
    states = [
        st.session_state[f"input{i}"] for i in range(len(matches))
    ]
    for i, match in enumerate(matches):
        scores[match['id']] = float(states[i])
    # reset states to 1.0
    for i in range(len(matches)):
        st.session_state[f"input{i}"] = 1.0
    # get training data and labels
    X = list([match['vector'] for match in matches])
    y = [v for v in list(scores.values())]
    tune(X, y, iters=int(st.session_state.iters))
    # update record metadata after training
    for match in matches:
        st.session_state.meta[match['id']] = 1
    logging.info(f"Exclude List: {st.session_state.meta}")

def delete_element(element):
    del element

def init_clip_mlang():
    """ Initialize the multilingual CLIP model

    Returns:
        tokenizer: AutoTokenizer (which converts words into embeddings)
        clip: the multilingual CLIP text model
    """
    MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return tokenizer, clip


def init_clip_vanilla():
    """ Initialize the vanilla CLIP model

    Returns:
        tokenizer: CLIPTokenizerFast (which converts words into embeddings)
        clip: the CLIP model
    """
    MODEL_ID = "openai/clip-vit-base-patch32"
    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
    clip = CLIPModel.from_pretrained(MODEL_ID)
    return tokenizer, clip


def init_clip_rsicd():
    """ Initialize the CLIP model fine-tuned on RSICD

    Returns:
        tokenizer: CLIPTokenizerFast (which converts words into embeddings)
        clip: the CLIP model
    """
    MODEL_ID = "flax-community/clip-rsicd"
    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
    clip = CLIPModel.from_pretrained(MODEL_ID)
    return tokenizer, clip

def prompt2vec_mlang(prompt: str, tokenizer, clip):
    """ Convert a prompt into a query vector

    Args:
        prompt (str): Text to be tokenized

    Returns:
        xq: vector from the text encoder, representing the original prompt
    """
    out = clip.forward(prompt, tokenizer)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq


def prompt2vec_vanilla(prompt: str, tokenizer, clip):
    inputs = tokenizer(prompt, return_tensors='pt')
    out = clip.get_text_features(**inputs)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq

st.markdown(""" | |
<link | |
rel="stylesheet" | |
href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap" | |
/> | |
""", unsafe_allow_html=True) | |
messages = [
    f"""
    Find the most relevant examples from a large visual dataset by combining a text query and
    few-shot learning.
    """,
    f"""
    Then you can adjust the weight on each image. Those weights should **represent how well an
    image matches your preference**. You can either choose the images that match your prompt or
    change your mind.
    You might notice the iteration slider above the retrieved images. It controls how fast the
    query vector changes: more **iterations** change the vector faster, while fewer **iterations**
    make the retrieval smoother.
    """,
    f"""
    This demo trains a classifier to distinguish between the samples you want and the samples
    you don't want. By initializing its weights from the prompt, you get a classifier that is good
    enough to cluster the images you are searching for. If the result is not as good as you expected,
    you can also supervise the classifier with the **Relevance** annotations. If you cannot see any
    difference in the Top-K retrieved results, try increasing the **Number of Iterations**.
    """,
    # TODO @ fangruil: fill the link with our tech blog
    f"""
    The app uses [MyScale](http://mqdb.page.moqi.ai/mqdb-docs/) to store and query images
    using vector search. All images are sourced from the
    [Unsplash Lite dataset](https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip)
    and encoded using [OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain how
    it all works [here]().
    """
]

text_model_map = {
    'Multi Lingual': {'Vanilla CLIP': [prompt2vec_mlang, ]},
    'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
                'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
                }
}

with st.spinner("Connecting DB..."): | |
st.session_state.meta, st.session_state.index = init_db() | |
with st.spinner("Loading Models..."): | |
# Initialize CLIP model | |
if 'xq' not in st.session_state: | |
text_model_map['Multi Lingual']['Vanilla CLIP'].append( | |
init_clip_mlang()) | |
text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla()) | |
text_model_map['English']['CLIP finetuned on RSICD'].append( | |
init_clip_rsicd()) | |
st.session_state.query_num = 0 | |
if 'xq' not in st.session_state: | |
# If it's a fresh start | |
if st.session_state.query_num < len(messages): | |
msg = messages[st.session_state.query_num] | |
else: | |
msg = messages[-1] | |
prompt = '' | |
    # Basic Layout
    with st.container():
        if 'prompt' in st.session_state:
            del st.session_state.prompt
        st.title("Visual Dataset Explorer")
        start = [st.empty(), st.empty(), st.empty(), st.empty(),
                 st.empty(), st.empty(), st.empty(), st.empty()]
        start[0].info(msg)
        start_col = start[1].columns(3)
        st.session_state.db_name_ref = start_col[0].selectbox(
            "Select Database:", list(db_name_map.keys()))
        st.session_state.lang = start_col[1].selectbox(
            "Select Language:", list(text_model_map.keys()))
        st.session_state.feat_name = start_col[2].selectbox(
            "Select Image Feature:",
            list(text_model_map[st.session_state.lang].keys()))
        if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
            start[2].warning('If you are searching for remote sensing images, \
                try the prompt "An aerial photograph of <your-real-query>" \
                for the best search experience!')
        if len(prompt) > 0:
            st.session_state.prompt = prompt.replace(' ', '_')
        start[4].markdown(
            '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
            <p>🌟 We also support multi-language search. Type in any language you know to search! ⌨️ </p>',
            unsafe_allow_html=True)
        upld_model = start[6].file_uploader(
            "Or you can upload your previous run!", type='onnx')
        upld_btn = start[7].button(
            "Use Loaded Weights", disabled=upld_model is None)
        prompt = start[3].text_input(
            "Prompt:",
            value="An aerial photograph of " if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K" else "",
            placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...",)
        with start[5]:
            col = st.columns(8)
            has_no_prompt = (len(prompt) == 0 and upld_model is None)
            prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
            random_xq = col[7].button("Random", disabled=not has_no_prompt)
    if random_xq:
        # Randomly pick a vector to query
        xq, orig_xq = init_random_query()
        st.session_state.xq = xq
        st.session_state.orig_xq = orig_xq
        _ = [elem.empty() for elem in start]
    elif prompt_xq or upld_btn:
        if upld_model is not None:
            # Import the vector from an uploaded ONNX file
            import onnx
            from onnx import numpy_helper
            _model = onnx.load(upld_model)
            weights = _model.graph.initializer
            assert len(weights) == 1
            xq = numpy_helper.to_array(weights[0]).tolist()
            assert len(xq) == DIMS
            st.session_state.prompt = upld_model.name.split(".onnx")[
                0].replace(' ', '_')
        else:
            print(f"Input prompt is {prompt}")
            # Encode the prompt into a query vector
            p2v_func, args = text_model_map[st.session_state.lang][st.session_state.feat_name]
            xq = p2v_func(prompt, *args)
        st.session_state.xq = xq
        st.session_state.orig_xq = xq
        _ = [elem.empty() for elem in start]

if 'xq' in st.session_state:
    # If it is not a fresh start
    if st.session_state.query_num + 1 < len(messages):
        msg = messages[st.session_state.query_num + 1]
    else:
        msg = messages[-1]
    # initialize the classifier
    if 'clf' not in st.session_state:
        st.session_state.clf = Classifier(st.session_state.xq)
    # if we want to display images we end up here
    st.info(msg)
    # first retrieve images from the database
    st.session_state.matches, st.session_state.top_k = get_top_k(
        st.session_state.clf.get_weights(), top_k=9)
    # export the model into executable ONNX
    st.session_state.dnld_model = BytesIO()
    torch.onnx.export(torch.nn.Sequential(NormalizingLayer(), st.session_state.clf.model),
                      torch.as_tensor(st.session_state.xq).reshape(1, -1),
                      st.session_state.dnld_model,
                      input_names=['input'],
                      output_names=['output'])
    with st.container():
        with st.sidebar:
            with st.container():
                st.header("Top K Nearest in Database")
                for i, k in enumerate(st.session_state.top_k):
                    url = k["url"]
                    url += "?q=75&fm=jpg&w=200&fit=max"
                    # flag known-broken image links
                    disabled = k["id"] in BAD_IDS
                    dist = np.matmul(st.session_state.clf.get_weights() / np.linalg.norm(st.session_state.clf.get_weights()),
                                     np.array(k["vector"]).T)
                    st.markdown(card_with_conf(i, dist, url),
                                unsafe_allow_html=True)
            dnld_nam = st.text_input('Download Name:',
                                     f'{(st.session_state.prompt if "prompt" in st.session_state else "model")}.onnx',
                                     max_chars=50)
            dnld_btn = st.download_button('Download your classifier!',
                                          st.session_state.dnld_model,
                                          dnld_nam,)
        # once retrieved, display the images alongside relevance sliders in a form
        with st.form("batch", clear_on_submit=False):
            st.session_state.iters = st.slider(
                "Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
            col = st.columns([1, 9])
            col[0].form_submit_button("Train!", on_click=submit)
            col[1].form_submit_button(
                "Choose a new prompt", on_click=refresh_index)
            # we have three columns in the form
            cols = st.columns(3)
            for i, match in enumerate(st.session_state.matches):
                # build the image url
                url = match["url"]
                url += "?q=75&fm=jpg&w=200&fit=max"
                # disable the slider for known-broken image links
                disabled = match["id"] in BAD_IDS
                # the card shows an image and a relevance slider
                cols[i % 3].markdown(card(i, url), unsafe_allow_html=True)
                # we access the slider values via st.session_state[f"input{i}"]
                cols[i % 3].slider(
                    "Relevance",
                    min_value=0.0,
                    max_value=1.0,
                    value=1.0,
                    step=0.05,
                    key=f"input{i}",
                    disabled=disabled
                )