import base64
import logging
from io import BytesIO
from os import environ
environ['TOKENIZERS_PARALLELISM'] = 'true'
import numpy as np
import streamlit as st
import torch
from multilingual_clip import pt_multilingual_clip
from transformers import AutoTokenizer
from myscaledb import Client
DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
DIMS = 512
# Ignore some bad links (already broken in the dataset)
BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8', 'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
# NOTE: the caching decorators below are an assumption -- refresh_index() calls
# init_db.clear(), which implies these loaders were cached with Streamlit's
# st.experimental_singleton API (current in the Streamlit 1.x era of this app).
@st.experimental_singleton
def init_clip():
    """ Initialize the multilingual CLIP model
    Returns:
        tokenizer: tokenizer for MODEL_ID (converts text into token ids)
        clip: the M-CLIP text encoder used to embed prompts
    """
    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return tokenizer, clip
@st.experimental_singleton
def init_db():
    """ Initialize the database connection
    Returns:
        meta_field: dict recording whether an image has been viewed
        client: database connection object
    """
    client = Client(url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
    # check that the connection is alive
    assert client.is_alive()
    meta_field = {}
    return meta_field, client
def init_query_num():
    print("init query_num")
    return 0
def query(xq, top_k=10):
    """ Query the top-K matches for a given vector
    Args:
        xq (numpy.ndarray or list of floats): Query vector
        top_k (int, optional): Number of matched vectors. Defaults to 10.
    Returns:
        matches: list of Record objects; keys refer to the selected columns
        top_k: the unfiltered result set, used for the sidebar display
    """
    attempt = 0
    xq = xq / np.linalg.norm(xq)
    while attempt < 3:
        try:
            xq_s = f"[{', '.join([str(float(fnum)) for fnum in list(xq)])}]"
            print('Excluded pre:', st.session_state.meta)
            if len(st.session_state.meta) > 0:
                exclude_list = ','.join([f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
                print("Excluded:", exclude_list)
                # PREWHERE filters on columns before the vector search runs
                xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
                    distance('topK={top_k}')(vector, {xq_s}) AS dist\
                    FROM {DB_NAME} PREWHERE id NOT IN ({exclude_list})")
            else:
                xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
                    distance('topK={top_k}')(vector, {xq_s}) AS dist\
                    FROM {DB_NAME}")
            # real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
            #     1 - arraySum(arrayMap((x, y) -> x * y, {xq_s}, vector)) AS dist\
            #     FROM {DB_NAME} ORDER BY dist LIMIT {top_k}")
            # FIXME: the query above causes freezing on the DB
            real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
                distance('topK={top_k}')(vector, {xq_s}) AS dist\
                FROM {DB_NAME}")
            # reuse `top_k` to carry the unfiltered results back to the caller
            top_k = real_xc
            # drop anything the user has already seen
            xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or \
                st.session_state.meta[xi['id']] < 1]
            logging.info(f'{len(xc)} records returned, {[_i["id"] for _i in xc]}')
            matches = xc
            break
        except Exception as e:
            # force a reconnect if the connection (or anything else) fails
            logging.warning(str(e))
            _, st.session_state.index = init_db()
            attempt += 1
            matches = []
    if len(matches) == 0:
        logging.error(f"No matches found for '{DB_NAME}'")
    return matches, top_k
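# For reference, the SQL built above expands to something like this
# (vector truncated, excluded id hypothetical):
#   SELECT id, url, vector, distance('topK=10')(vector, [0.01, ...]) AS dist
#   FROM mqdb_demo.unsplash_25k_clip_indexer PREWHERE id NOT IN ('9_9hzZVjV8s')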
def init_random_query():
    # keep a copy of the random vector so we can measure drift later
    xq = np.random.rand(DIMS).tolist()
    return xq, xq.copy()
class Classifier:
    """ Zero-shot classifier
    This classifier acts as a proxy for the user's reaction to the probed images.
    Its weight vector replaces the original query vector generated from the prompt,
    steering retrieval toward results the user prefers. This is a common pattern in
    recommendation systems: the classifier recommends more precisely as it
    accumulates user feedback.
    """
    def __init__(self, xq: list):
        # initialize a linear model with DIMS inputs and 1 output
        # the bias is dropped, since we only care about the inner product
        self.model = torch.nn.Linear(DIMS, 1, bias=False)
        # use the initial query `xq` as the starting weights, so the first
        # scores are exactly the inner products with the prompt vector
        init_weight = torch.Tensor(xq).reshape(1, -1)
        self.model.weight = torch.nn.Parameter(init_weight)
        # binary cross-entropy loss and plain SGD
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
    def fit(self, X: list, y: list, iters: int = 5):
        # convert X and y to tensors
        X = torch.Tensor(X)
        y = torch.Tensor(y).reshape(-1, 1)
        for i in range(iters):
            # zero gradients
            self.optimizer.zero_grad()
            # normalize the weight before inference; this constrains the
            # gradient, otherwise the query vector's norm explodes
            self.model.weight.data = self.model.weight.data / torch.norm(self.model.weight.data, p=2, dim=-1)
            # forward pass
            out = self.model(X)
            # compute loss
            loss = self.loss(out, y)
            # backward pass
            loss.backward()
            # update weights
            self.optimizer.step()
    def get_weights(self):
        xq = self.model.weight.detach().numpy()[0].tolist()
        return xq
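# A minimal sketch of the feedback loop (names and values hypothetical):
#   clf = Classifier(xq)                       # weights start as the prompt vector
#   clf.fit(X=[v_good, v_bad], y=[1.0, 0.0])   # 1.0 = relevant, 0.0 = irrelevant
#   xq = clf.get_weights()                     # tuned vector for the next query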
def prompt2vec(prompt: str):
    """ Convert a prompt into a query vector
    Args:
        prompt (str): Text to be encoded
    Returns:
        xq: embedding from the CLIP text encoder, representing the original prompt
    """
    # inputs = tokenizer(prompt, return_tensors='pt')
    # out = clip.get_text_features(**inputs)
    out = clip.forward(prompt, tokenizer)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq
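# Example (hypothetical prompt): prompt2vec('a photo of white dogs') returns a
# plain list of DIMS floats that can seed a Classifier or be passed to query().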
def pil_to_bytes(img):
    """ Convert a Pillow image into a base64-encoded JPEG
    Args:
        img (PIL.Image): Pillow (PIL) Image
    Returns:
        img_bin: image encoded as a base64 string
    """
    with BytesIO() as buf:
        img.save(buf, format='jpeg')
        img_bin = buf.getvalue()
        img_bin = base64.b64encode(img_bin).decode('utf-8')
    return img_bin
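# The returned string can be inlined as an HTML data URI, e.g. (illustrative only):
#   st.markdown(f'<img src="data:image/jpeg;base64,{pil_to_bytes(img)}">', unsafe_allow_html=True)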
def card(i, url):
    return f'<img id="img{i}" src="{url}" width="200px;">'
def card_with_conf(i, conf, url):
    conf = "%.4f" % (conf)
    return f'<img id="img{i}" src="{url}" width="200px;" style="margin:50px 50px"><b>Relevance: {conf}</b>'
def get_top_k(xq, top_k=9):
    """ Wrapper around query()
    Args:
        xq (numpy.ndarray or list of floats): Query vector
        top_k (int, optional): Number of returned vectors. Defaults to 9.
    Returns:
        The (matches, top_k) tuple from `query()`
    """
    matches = query(
        xq, top_k=top_k
    )
    return matches
def tune(X, y, iters=2):
    """ Train the zero-shot classifier
    Args:
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by the user
        iters (int, optional): Number of update iterations to run
    """
    # train the classifier
    st.session_state.clf.fit(X, y, iters=iters)
    # extract the new query vector
    st.session_state.xq = st.session_state.clf.get_weights()
def refresh_index():
    """ Clean the session
    """
    del st.session_state["meta"]
    st.session_state.meta = {}
    st.session_state.query_num = 0
    logging.info(f"Refresh for '{st.session_state.meta}'")
    # clear the cached connection so init_db() builds a fresh one
    init_db.clear()
    # refresh session states
    st.session_state.meta, st.session_state.index = init_db()
    del st.session_state.clf, st.session_state.xq
def calc_dist():
    # distance between the tuned query vector and the original prompt vector
    xq = np.array(st.session_state.xq)
    orig_xq = np.array(st.session_state.orig_xq)
    return np.linalg.norm(xq - orig_xq)
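# Note: calc_dist() is 0.0 right after a prompt is submitted (xq == orig_xq) and
# grows as the classifier tunes the query vector away from the original prompt.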
def submit():
    """ Tune the model w.r.t. the scores given by the user.
    """
    st.session_state.query_num += 1
    matches = st.session_state.matches
    velocity = 1  # st.session_state.velocity
    scores = {}
    states = [
        st.session_state[f"input{i}"] for i in range(len(matches))
    ]
    for i, match in enumerate(matches):
        scores[match['id']] = float(states[i])
    # reset states to 1.0
    for i in range(len(matches)):
        st.session_state[f"input{i}"] = 1.0
    # get training data and labels
    X = list([match['vector'] for match in matches])
    y = [v for v in list(scores.values())]
    tune(X, y, iters=int(st.session_state.iters))
    # mark every shown image as viewed so it is excluded from future queries
    for match in matches:
        st.session_state.meta[match['id']] = 1
    logging.info(f"Exclude List: {st.session_state.meta}")
def delete_element(element):
    # NOTE: this only unbinds the local name; it does not remove the widget itself
    del element
st.markdown("""
<link
  rel="stylesheet"
  href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
/>
""", unsafe_allow_html=True)
messages = [
    f"""
    Find the most relevant examples in a large visual dataset by combining a text query with few-shot learning.
    """,
    f"""
    Now you can adjust the weight on each image. Those weights should **represent how well each image
    matches your preference**. You can either choose images that match your prompt or change
    your mind.
    You might also notice the iteration slider above the retrieved images. It controls
    how fast the query vector changes: more **iterations** change the vector faster, while
    fewer **iterations** make retrieval smoother.
    """,
    f"""
    This demo trains a classifier to distinguish between samples you want and samples
    you don't. By initializing its weights from the prompt, you get a classifier good enough to cluster
    the images you are searching for. If the result is not as good as you expected, you can
    supervise the classifier with the **Relevance** annotations. If you cannot see any difference in the Top-K
    retrieved results, try increasing the **Number of Iterations**.
    """,
    # TODO @ fangruil: fill the link with our tech blog
    f"""
    The app uses [MyScale](http://mqdb.page.moqi.ai/mqdb-docs/) to store and query images
    with vector search. All images are sourced from the
    [Unsplash Lite dataset](https://unsplash-datasets.s3.amazonaws.com/lite/latest/unsplash-research-dataset-lite-latest.zip)
    and encoded using [OpenAI's CLIP](https://huggingface.co/openai/clip-vit-base-patch32). We explain how
    it all works [here]().
    """
]
with st.spinner("Connecting DB..."):
    st.session_state.meta, st.session_state.index = init_db()
with st.spinner("Loading Models..."):
    # Initialize CLIP model
    if 'xq' not in st.session_state:
        tokenizer, clip = init_clip()
        st.session_state.query_num = 0
if 'xq' not in st.session_state:
    # If it's a fresh start
    if st.session_state.query_num < len(messages):
        msg = messages[st.session_state.query_num]
    else:
        msg = messages[-1]
    # Basic Layout
    with st.container():
        st.title("Visual Dataset Explorer")
        start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()]
        start[0].info(msg)
        prompt = start[1].text_input("Prompt:", value="", placeholder="Examples: a photo of white dogs, cats in the snow, a house by the lake")
        start[2].markdown(
            '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>',
            unsafe_allow_html=True)
        with start[3]:
            col = st.columns(8)
            prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
            random_xq = col[7].button("Random", disabled=len(prompt) != 0)
    if random_xq:
        # randomly pick a vector to query
        xq, orig_xq = init_random_query()
        st.session_state.xq = xq
        st.session_state.orig_xq = orig_xq
        _ = [elem.empty() for elem in start]
    elif prompt_xq:
        print(f"Input prompt is {prompt}")
        # encode the prompt into a query vector
        xq = prompt2vec(prompt)
        st.session_state.xq = xq
        st.session_state.orig_xq = xq
        _ = [elem.empty() for elem in start]
if 'xq' in st.session_state:
    # If it is not a fresh start
    if st.session_state.query_num + 1 < len(messages):
        msg = messages[st.session_state.query_num + 1]
    else:
        msg = messages[-1]
    # initialize classifier
    if 'clf' not in st.session_state:
        st.session_state.clf = Classifier(st.session_state.xq)
    # if we want to display images we end up here
    st.info(msg)
    # first retrieve images from MyScale
    st.session_state.matches, st.session_state.top_k = get_top_k(st.session_state.clf.get_weights(), top_k=9)
    with st.container():
        with st.sidebar:
            with st.container():
                st.header("Top K Nearest in Database")
                for i, k in enumerate(st.session_state.top_k):
                    url = k["url"]
                    url += "?q=75&fm=jpg&w=200&fit=max"
                    if k["id"] not in BAD_IDS:
                        disabled = False
                    else:
                        disabled = True
                    # relevance = inner product of the normalized query vector and the image vector
                    dist = np.matmul(st.session_state.clf.get_weights() / np.linalg.norm(st.session_state.clf.get_weights()),
                                     np.array(k["vector"]).T)
                    st.markdown(card_with_conf(i, dist, url), unsafe_allow_html=True)
        # once retrieved, display them alongside sliders in a form
        with st.form("batch", clear_on_submit=False):
            st.session_state.iters = st.slider("Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
            col = st.columns([1, 9])
            col[0].form_submit_button("Train!", on_click=submit)
            col[1].form_submit_button("Choose a new prompt", on_click=refresh_index)
            # we have three columns in the form
            cols = st.columns(3)
            for i, match in enumerate(st.session_state.matches):
                # build the thumbnail url
                url = match["url"]
                url += "?q=75&fm=jpg&w=200&fit=max"
                if match["id"] not in BAD_IDS:
                    disabled = False
                else:
                    disabled = True
                # the card shows an image and a relevance slider
                cols[i % 3].markdown(card(i, url), unsafe_allow_html=True)
                # we access the slider values via st.session_state[f"input{i}"]
                cols[i % 3].slider(
                    "Relevance",
                    min_value=0.0,
                    max_value=1.0,
                    value=1.0,
                    step=0.05,
                    key=f"input{i}",
                    disabled=disabled
                )