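# NOTE(review): "Spaces: Sleeping" appears to be hosting-page status text captured with this file; it is not part of the program.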
import base64 | |
import io | |
import os | |
import random | |
import gradio as gr | |
import numpy as np | |
import torch | |
from colpali_engine.models import ColPali, ColPaliProcessor | |
from datasets import load_dataset | |
from dotenv import load_dotenv | |
from PIL import Image, ImageDraw | |
from qdrant_client import QdrantClient | |
from qdrant_client.http import models | |
from tqdm import tqdm | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts, sizes | |
from typing import Iterable | |
# Load environment variables
load_dotenv()

# Pick the compute device: CUDA first, then MPS, falling back to CPU
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")

# Qdrant connection settings are read from the environment
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=True)

# Dataset holding the images, and the Qdrant collection that is queried
dataset = load_dataset("davanstrien/ufo-ColPali", split="train")
collection_name = "ufo"

# ColPali model (loaded in bfloat16 on the selected device) and its query processor
model_name = "davanstrien/finetune_colpali_v1_2-ufo-4bit"
colpali_model = ColPali.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map=device
)
colpali_processor = ColPaliProcessor.from_pretrained("vidore/colpaligemma-3b-pt-448-base")
def search_images_by_text(query_text, top_k=5):
    """Encode a text query with ColPali and search the Qdrant collection.

    Args:
        query_text: The text to search for.
        top_k: Value passed to Qdrant as the result ``limit``.

    Returns:
        The response object returned by ``qdrant_client.query_points``.
    """
    # Encoding is inference only, so gradient tracking is switched off
    with torch.no_grad():
        processed = colpali_processor.process_queries([query_text])
        processed = processed.to(colpali_model.device)
        embeddings = colpali_model(**processed)

    # Take the single query in the batch and convert it to nested Python lists
    query_vectors = embeddings[0].cpu().float().numpy().tolist()

    response = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vectors,
        limit=top_k,
        timeout=60,
    )
    print(response)
    return response
def search_by_text_and_return_images(query_text, top_k=5):
    """Search for *query_text* and return the matching dataset images.

    Args:
        query_text: The text to search for. ``None``, empty, or
            whitespace-only queries produce an empty result without
            running the model or contacting Qdrant.
        top_k: Maximum number of results; coerced to ``int`` because a UI
            slider may deliver a float.

    Returns:
        A list with the ``"image"`` entries of the dataset rows whose
        indices are the ids of the returned points, or ``[]`` when there
        is nothing to search for or nothing was found.
    """
    # Nothing to search for: skip the model call and the Qdrant round-trip
    if not query_text or not query_text.strip():
        return []

    # search_images_by_text already prints the raw response, so it is not
    # printed a second time here
    results = search_images_by_text(query_text, int(top_k))

    # Point ids are used directly as dataset row indices
    row_ids = [point.id for point in results.points]
    if not row_ids:
        return []

    subset = dataset.select(row_ids)
    return list(subset["image"])
class Geocities90s(Base):
    """Gradio theme built on ``Base`` with a retro colour and font setup.

    Defaults to yellow/purple/gray hues, the "Comic Neue" and "VT323"
    Google fonts, a ``stars.gif`` page background and gradient primary
    buttons. All constructor arguments are keyword-only.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.yellow,
        secondary_hue: colors.Color | str = colors.purple,
        neutral_hue: colors.Color | str = colors.gray,
        font: fonts.Font | str = fonts.GoogleFont("Comic Neue"),
        font_mono: fonts.Font | str = fonts.GoogleFont("VT323"),
    ):
        # Requested font first, followed by the fallback families
        body_fonts = (font, "Comic Sans MS", "ui-sans-serif", "sans-serif")
        mono_fonts = (font_mono, "Courier New", "monospace")

        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            font=body_fonts,
            font_mono=mono_fonts,
        )

        # Theme variable overrides applied on top of the base theme
        overrides = {
            "body_background_fill": "url('https://web.archive.org/web/20091020152706/http://hk.geocities.com/neonlightfantasy/image/stars.gif')",
            "button_primary_background_fill": "linear-gradient(90deg, *primary_500, *secondary_500)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_500, *primary_500)",
            "button_primary_text_color": "*neutral_50",
        }
        self.set(**overrides)


# Theme instance handed to gr.Blocks below
geocities90s = Geocities90s()
# Custom stylesheet passed to gr.Blocks: page text colours, the container
# background image, a colour-cycling <h1>, and the .yellow-text helper class.
css = """
body {
    margin: 0;
    padding: 0;
    color: #00ff00;
    font-family: 'Comic Sans MS', cursive;
}
.gradio-container {
    background-image: url('https://i.ytimg.com/vi/5WapcCXEcXA/maxresdefault.jpg');
    background-repeat: repeat;
    background-size: 300px 300px;
}
h1 {
    text-align: center;
    color: #ff00ff;
    text-shadow: 2px 2px #000000;
    font-size: 36px;
    animation: flash 1s linear infinite;
}
@keyframes flash {
    0% { color: #ff00ff; }
    50% { color: #00ffff; }
    100% { color: #ff00ff; }
}
.yellow-text {
    color: #ffff00;
    text-shadow: 2px 2px #000000;
    font-weight: bold;
}
"""
# Gradio Blocks UI: header HTML, a query box with a result-count slider,
# a search button, and a gallery that shows the returned images.
with gr.Blocks(css=css, theme=geocities90s) as demo:
    gr.HTML("<h1>🛸 Top Secret UFO Document Search 🛸</h1>")
    gr.HTML(
        "<p style='text-align: center; font-style: italic;'>Powered by <a href='https://danielvanstrien.xyz/posts/post-with-code/colpali-qdrant/2024-10-02_using_colpali_with_qdrant.html' target='_blank' style='color: #00ff00;'>ColPali and Qdrant</a></p>"
    )
    gr.HTML(
        "<p style='text-align: center; color: #ff00ff;'>👽 Discover how to build your own alien-approved search engine! Learn the secrets of ColPali and Qdrant, and join the ranks of interstellar code warriors. Warning: May attract Men in Black. 🕴️👽</p>"
    )
    gr.HTML(
        "<marquee direction='left' scrollamount='5' class='yellow-text'>Uncover the truth that's out there! The government doesn't want you to know! ColPali will reveal the truth!</marquee>"
    )

    # Query text (wide column) next to the result-count slider (narrow column)
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Enter your cosmic query",
                placeholder="e.g., alien abduction, crop circles",
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Number of classified documents",
                value=5,
            )

    # NOTE(review): button and gallery are assumed to sit below the row,
    # directly inside the Blocks context — confirm against the intended layout.
    search_button = gr.Button("Declassify Documents")
    gallery_output = gr.Gallery(label="Declassified UFO Sightings", elem_id="gallery")

    # Clicking the button runs the search and fills the gallery with images
    search_button.click(
        fn=search_by_text_and_return_images,
        inputs=[query_input, num_results],
        outputs=gallery_output,
    )
# Launch the Gradio app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()