"""Gradio demo: text search over the ``ufo`` Qdrant collection using ColPali.

At import time this module loads environment variables, connects to Qdrant,
loads the ``davanstrien/ufo-ColPali`` dataset and the ColPali model/processor,
and builds the Gradio interface. Running it as a script launches the UI.
"""

import base64
import io
import os
import random
from typing import Iterable

import gradio as gr
import numpy as np
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from datasets import load_dataset
from dotenv import load_dotenv
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from PIL import Image, ImageDraw
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm

# Load environment variables (QDRANT_URL / QDRANT_API_KEY are read below).
load_dotenv()

# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Set up the Qdrant client.
# NOTE(review): os.getenv returns None when a variable is unset and that None
# is passed straight to QdrantClient — confirm that is the intended fallback.
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
)

# Load the dataset and set up the model and its processor.
dataset = load_dataset("davanstrien/ufo-ColPali", split="train")
collection_name = "ufo"
model_name = "davanstrien/finetune_colpali_v1_2-ufo-4bit"
colpali_model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
colpali_processor = ColPaliProcessor.from_pretrained(
    "vidore/colpaligemma-3b-pt-448-base"
)


def search_images_by_text(query_text: str, top_k: int = 5):
    """Embed *query_text* with ColPali and query the Qdrant collection.

    Args:
        query_text: Free-text query to embed.
        top_k: Maximum number of points to request. Coerced to ``int``
            because the value may come from a UI slider.

    Returns:
        The response object returned by ``qdrant_client.query_points``; its
        ``points`` attribute is what the caller in this file reads.
    """
    # Inference only: no gradients are needed for the query embedding.
    with torch.no_grad():
        batch_query = colpali_processor.process_queries([query_text]).to(
            colpali_model.device
        )
        query_embedding = colpali_model(**batch_query)

    # Only one query was submitted, so take element 0. The tensor is moved to
    # CPU and cast to float32 before NumPy conversion (the model is loaded in
    # bfloat16), then turned into plain nested lists for the client.
    multivector_query = query_embedding[0].cpu().float().numpy().tolist()

    results = qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        # Guard against a float slider value reaching the client's limit field.
        limit=int(top_k),
        timeout=60,
    )
    # Debug output of the raw response (printed once per request, here only).
    print(results)
    return results


def search_by_text_and_return_images(query_text: str, top_k: int = 5) -> list:
    """Return the dataset images matching *query_text* (Gradio handler).

    Args:
        query_text: Query typed by the user. A blank or ``None`` query yields
            an empty list without running the model or contacting Qdrant.
        top_k: Maximum number of images to return.

    Returns:
        A list of the ``image`` column values for the matched dataset rows.
    """
    if not query_text or not query_text.strip():
        return []

    results = search_images_by_text(query_text, top_k)
    # Point ids are used directly as row indices into ``dataset``.
    # NOTE(review): assumes the collection was indexed with the dataset row
    # index as the point id — confirm against the indexing script.
    row_ids = [point.id for point in results.points]
    subset = dataset.select(row_ids)
    return list(subset["image"])


class Geocities90s(Base):
    """Gradio theme with yellow/purple hues, Comic Neue / VT323 fonts, a tiled
    background image and gradient primary buttons."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.yellow,
        secondary_hue: colors.Color | str = colors.purple,
        neutral_hue: colors.Color | str = colors.gray,
        font: fonts.Font | str = fonts.GoogleFont("Comic Neue"),
        font_mono: fonts.Font | str = fonts.GoogleFont("VT323"),
    ):
        """Initialise the base theme, then override selected theme variables.

        Args:
            primary_hue: Primary colour (or colour name) for the theme.
            secondary_hue: Secondary colour (or colour name).
            neutral_hue: Neutral colour (or colour name).
            font: Main font; followed by fixed fallbacks in the font stack.
            font_mono: Monospace font; followed by fixed fallbacks.
        """
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            # The requested font comes first, then fallbacks.
            font=(font, "Comic Sans MS", "ui-sans-serif", "sans-serif"),
            font_mono=(font_mono, "Courier New", "monospace"),
        )
        self.set(
            body_background_fill="url('https://web.archive.org/web/20091020152706/http://hk.geocities.com/neonlightfantasy/image/stars.gif')",
            button_primary_background_fill="linear-gradient(90deg, *primary_500, *secondary_500)",
            # Hover state reverses the gradient direction of the colours.
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *primary_500)",
            button_primary_text_color="*neutral_50",
        )


geocities90s = Geocities90s()

# Extra CSS passed to the interface on top of the theme.
css = """
body {
    margin: 0;
    padding: 0;
    color: #00ff00;
    font-family: 'Comic Sans MS', cursive;
}
.gradio-container {
    background-image: url('https://i.ytimg.com/vi/5WapcCXEcXA/maxresdefault.jpg');
    background-repeat: repeat;
    background-size: 300px 300px;
}
h1 {
    text-align: center;
    color: #ff00ff;
    text-shadow: 2px 2px #000000;
    font-size: 36px;
}
.yellow-text {
    color: #ffff00;
    text-shadow: 2px 2px #000000;
    font-weight: bold;
}
"""

# Interface: a query textbox and a result-count slider feed the search
# handler; results are shown in a gallery.
demo = gr.Interface(
    fn=search_by_text_and_return_images,
    inputs=[
        gr.Textbox(
            label="Enter your cosmic query",
            placeholder="e.g., alien abduction, crop circles",
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            step=1,
            label="Number of classified documents",
            value=5,
        ),
    ],
    outputs=gr.Gallery(label="Declassified UFO Sightings", elem_id="gallery"),
    title="🛸 Top Secret UFO Document Search 🛸",
    description="Uncover the truth that's out there! The government doesn't want you to know!",
    css=css,
    allow_flagging="never",
    theme=geocities90s,
)

if __name__ == "__main__":
    demo.launch()