# Author: davanstrien (HF staff)
# Commit 46a780a - chore: Add UFO search engine details to app.py
import base64
import io
import os
import random
import gradio as gr
import numpy as np
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from datasets import load_dataset
from dotenv import load_dotenv
from PIL import Image, ImageDraw
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from typing import Iterable
# Load environment variables from a .env file (if present) before they are
# read below with os.getenv.
load_dotenv()
# Set up device: prefer the first CUDA GPU, then Apple's MPS backend, and fall
# back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")
# Qdrant connection settings are read from the environment; either value is
# None when the corresponding variable is not set.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# Shared client used by the search functions below; gRPC transport is requested.
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=True)
# Name of the Qdrant collection that is queried, and the ColPali checkpoint
# used to embed text queries.
collection_name = "ufo"
model_name = "davanstrien/finetune_colpali_v1_2-ufo-4bit"

# Page images; search hits are mapped back to rows of this dataset by point id.
dataset = load_dataset("davanstrien/ufo-ColPali", split="train")

# The model is loaded in bfloat16 onto the device selected above.
colpali_model = ColPali.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map=device
)
# The processor is loaded from the base checkpoint rather than from model_name.
colpali_processor = ColPaliProcessor.from_pretrained("vidore/colpaligemma-3b-pt-448-base")
def search_images_by_text(query_text, top_k=5):
    """Embed a text query with ColPali and search the Qdrant collection.

    Args:
        query_text: Free-text query to embed.
        top_k: Maximum number of points to request. Coerced to ``int`` because
            the value is wired to a UI slider, which can deliver a float.

    Returns:
        The response object returned by ``qdrant_client.query_points``; the
        matching points are available on its ``points`` attribute.
    """
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        batch_query = colpali_processor.process_queries([query_text]).to(
            colpali_model.device
        )
        query_embedding = colpali_model(**batch_query)

    # Take the single query in the batch; cast to float32 before converting,
    # since the model runs in bfloat16, which NumPy cannot represent.
    multivector_query = query_embedding[0].cpu().float().numpy().tolist()
    results = qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        limit=int(top_k),
        timeout=60,
    )
    print(results)
    return results
def search_by_text_and_return_images(query_text, top_k=5):
    """Return the dataset images that best match a text query.

    Args:
        query_text: Free-text query forwarded to ``search_images_by_text``.
        top_k: Maximum number of images to return.

    Returns:
        A list with the ``image`` column of the matching dataset rows.
    """
    # The response is not printed here: search_images_by_text already prints it.
    results = search_images_by_text(query_text, top_k)
    # Point ids are used directly as row indices into the dataset.
    row_ids = [point.id for point in results.points]
    subset = dataset.select(row_ids)
    return list(subset["image"])
class Geocities90s(Base):
    """Gradio theme with a retro look: yellow/purple hues and Comic-style fonts.

    All constructor arguments are keyword-only and are forwarded to the base
    theme; the body background and primary-button fills are then overridden.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.yellow,
        secondary_hue: colors.Color | str = colors.purple,
        neutral_hue: colors.Color | str = colors.gray,
        font: fonts.Font | str = fonts.GoogleFont("Comic Neue"),
        font_mono: fonts.Font | str = fonts.GoogleFont("VT323"),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            # The requested font comes first, followed by fallbacks.
            font=(font, "Comic Sans MS", "ui-sans-serif", "sans-serif"),
            font_mono=(font_mono, "Courier New", "monospace"),
        )
        # Starfield page background and two-colour gradient buttons whose
        # gradient direction flips on hover.
        self.set(
            body_background_fill="url('https://web.archive.org/web/20091020152706/http://hk.geocities.com/neonlightfantasy/image/stars.gif')",
            button_primary_background_fill="linear-gradient(90deg, *primary_500, *secondary_500)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *primary_500)",
            button_primary_text_color="*neutral_50",
        )
# Theme instance passed to gr.Blocks below.
geocities90s = Geocities90s()
# Extra stylesheet passed to gr.Blocks below: green body text, a tiled
# background image on the Gradio container, a colour-flashing <h1>, and the
# .yellow-text class used by the marquee banner.
css = """
body {
margin: 0;
padding: 0;
color: #00ff00;
font-family: 'Comic Sans MS', cursive;
}
.gradio-container {
background-image: url('https://i.ytimg.com/vi/5WapcCXEcXA/maxresdefault.jpg');
background-repeat: repeat;
background-size: 300px 300px;
}
h1 {
text-align: center;
color: #ff00ff;
text-shadow: 2px 2px #000000;
font-size: 36px;
animation: flash 1s linear infinite;
}
@keyframes flash {
0% { color: #ff00ff; }
50% { color: #00ffff; }
100% { color: #ff00ff; }
}
.yellow-text {
color: #ffff00;
text-shadow: 2px 2px #000000;
font-weight: bold;
}
"""
# Build the UI: header banners, a query box with a result-count slider, a
# search button, and a gallery that shows the matching document images.
with gr.Blocks(css=css, theme=geocities90s) as demo:
    gr.HTML("<h1>🛸 Top Secret UFO Document Search 🛸</h1>")
    gr.HTML(
        "<p style='text-align: center; font-style: italic;'>Powered by <a href='https://danielvanstrien.xyz/posts/post-with-code/colpali-qdrant/2024-10-02_using_colpali_with_qdrant.html' target='_blank' style='color: #00ff00;'>ColPali and Qdrant</a></p>"
    )
    gr.HTML(
        "<p style='text-align: center; color: #ff00ff;'>👽 Discover how to build your own alien-approved search engine! Learn the secrets of ColPali and Qdrant, and join the ranks of interstellar code warriors. Warning: May attract Men in Black. 🕴️👽</p>"
    )
    gr.HTML(
        "<marquee direction='left' scrollamount='5' class='yellow-text'>Uncover the truth that's out there! The government doesn't want you to know! ColPali will reveal the truth!</marquee>"
    )

    # Query text (wide column) next to the number-of-results slider (narrow column).
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Enter your cosmic query",
                placeholder="e.g., alien abduction, crop circles",
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Number of classified documents",
                value=5,
            )

    search_button = gr.Button("Declassify Documents")
    gallery_output = gr.Gallery(label="Declassified UFO Sightings", elem_id="gallery")

    # Clicking the button runs the search and fills the gallery with the images.
    search_button.click(
        fn=search_by_text_and_return_images,
        inputs=[query_input, num_results],
        outputs=gallery_output,
    )
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()