davanstrien HF staff committed on
Commit b36a913
1 Parent(s): 5fd9c92
Files changed (1)
  1. app.py +147 -110
app.py CHANGED
@@ -1,122 +1,159 @@
 import gradio as gr
 import torch
 from datasets import load_dataset
 from qdrant_client import QdrantClient
 from qdrant_client.http import models
-from colpali_engine.models import ColQwen2, ColQwen2Processor
-from PIL import Image
-import requests
-from io import BytesIO
-
-# Initialize the model, processor, and Qdrant client
-model_name = "vidore/colqwen2-v0.1"
-colpali_model = ColQwen2.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda:0")
-colpali_processor = ColQwen2Processor.from_pretrained(model_name)
-qdrant_client = QdrantClient(":memory:")
-collection_name = "image_collection"
-
-# Load the dataset (this should be done only once when setting up the app)
-dataset = load_dataset("davanstrien/loc-nineteenth-century-song-sheets", split="train")
-
-def setup_qdrant():
-    # Create a collection in Qdrant
-    qdrant_client.recreate_collection(
-        collection_name=collection_name,
-        vectors_config=models.VectorParams(
-            size=colpali_model.config.hidden_size,
-            distance=models.Distance.COSINE,
-            multivector_config=models.MultiVectorConfig(
-                comparator=models.MultiVectorComparator.MAX_SIM
-            ),
-        ),
-    )
-
-    # Index the dataset (this should be done only once when setting up the app)
-    batch_size = 32
-    for i in range(0, len(dataset), batch_size):
-        batch = dataset[i:i+batch_size]
-        images = batch['image']
-        with torch.no_grad():
-            batch_images = colpali_processor.process_images(images).to(colpali_model.device)
-            image_embeddings = colpali_model(**batch_images)
-
-        points = []
-        for j, embedding in enumerate(image_embeddings):
-            multivector = embedding.cpu().float().numpy().tolist()
-            points.append(models.PointStruct(
-                id=i+j,
-                vector=multivector,
-                payload={
-                    "item_id": batch['item_id'][j],
-                    "item_url": batch['item_url'][j]
-                }
-            ))
-
-        qdrant_client.upsert(
-            collection_name=collection_name,
-            points=points
-        )
-
-    print("Indexing complete!")
-
-def search_similar_images(query, top_k=5, mode="text"):
     with torch.no_grad():
-        if mode == "text":
-            batch_query = colpali_processor.process_queries([query]).to(colpali_model.device)
-        else:  # Image mode
-            batch_query = colpali_processor.process_images([query]).to(colpali_model.device)
         query_embedding = colpali_model(**batch_query)

     multivector_query = query_embedding[0].cpu().float().numpy().tolist()
-
-    search_result = qdrant_client.search(
         collection_name=collection_name,
-        query_vector=multivector_query,
-        limit=top_k
     )

-    return search_result
-
-def process_results(results):
-    output = []
-    for result in results:
-        item_url = result.payload['item_url']
-        score = result.score
-        output.append((item_url, f"Score: {score:.4f}"))
-    return output
-
-def text_search(query, top_k):
-    results = search_similar_images(query, top_k, mode="text")
-    return process_results(results)
-
-def image_search(image, top_k):
-    results = search_similar_images(image, top_k, mode="image")
-    return process_results(results)
-
-# Set up the Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Image Search App")
-    gr.Markdown("Search for similar images using text or image input.")
-
-    with gr.Tab("Text Search"):
-        text_input = gr.Textbox(label="Enter your search query")
-        text_button = gr.Button("Search")
-        text_output = gr.Gallery(label="Results", show_label=False, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto")
-        text_scores = gr.JSON(label="Scores")
-
-    with gr.Tab("Image Search"):
-        image_input = gr.Image(type="pil", label="Upload an image")
-        image_button = gr.Button("Search")
-        image_output = gr.Gallery(label="Results", show_label=False, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto")
-        image_scores = gr.JSON(label="Scores")
-
-    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
-
-    text_button.click(text_search, inputs=[text_input, top_k_slider], outputs=[text_output, text_scores])
-    image_button.click(image_search, inputs=[image_input, top_k_slider], outputs=[image_output, image_scores])
-
-# Run the setup (this should be done only once when deploying the app)
-setup_qdrant()
-
-# Launch the app
-demo.launch()
+import base64
+import io
+import os
+import random
+
 import gradio as gr
+import numpy as np
 import torch
+from colpali_engine.models import ColPali, ColPaliProcessor
 from datasets import load_dataset
+from dotenv import load_dotenv
+from PIL import Image, ImageDraw
 from qdrant_client import QdrantClient
 from qdrant_client.http import models
+from tqdm import tqdm
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+from typing import Iterable
+
+# Load environment variables
+load_dotenv()
+
+# Set up device
+if torch.cuda.is_available():
+    device = "cuda:0"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    device = "cpu"
+print(f"Using device: {device}")
+
+# Set up Qdrant client
+QDRANT_URL = os.getenv("QDRANT_URL")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+qdrant_client = QdrantClient(
+    url=QDRANT_URL,
+    api_key=QDRANT_API_KEY,
+    prefer_grpc=True,
+)
+
+
+# Load dataset and set up model
+dataset = load_dataset("davanstrien/ufo-ColPali", split="train")
+collection_name = "ufo"
+
+model_name = "davanstrien/finetune_colpali_v1_2-ufo-4bit"
+colpali_model = ColPali.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+)
+colpali_processor = ColPaliProcessor.from_pretrained(
+    "vidore/colpaligemma-3b-pt-448-base"
+)
+
+
+def search_images_by_text(query_text, top_k=5):
     with torch.no_grad():
+        batch_query = colpali_processor.process_queries([query_text]).to(
+            colpali_model.device
+        )
         query_embedding = colpali_model(**batch_query)

     multivector_query = query_embedding[0].cpu().float().numpy().tolist()
+    results = qdrant_client.query_points(
         collection_name=collection_name,
+        query=multivector_query,
+        limit=top_k,
+        timeout=60,
     )
+    print(results)
+    return results
+
+
+def search_by_text_and_return_images(query_text, top_k=5):
+    results = search_images_by_text(query_text, top_k)
+    print(results)
+    row_ids = [r.id for r in results.points]
+    subset = dataset.select(row_ids)
+    return list(subset["image"])
+
+
+class Geocities90s(Base):
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.yellow,
+        secondary_hue: colors.Color | str = colors.purple,
+        neutral_hue: colors.Color | str = colors.gray,
+        font: fonts.Font | str = fonts.GoogleFont("Comic Neue"),
+        font_mono: fonts.Font | str = fonts.GoogleFont("VT323"),
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            font=(font, "Comic Sans MS", "ui-sans-serif", "sans-serif"),
+            font_mono=(font_mono, "Courier New", "monospace"),
+        )
+        self.set(
+            body_background_fill="url('https://web.archive.org/web/20091020152706/http://hk.geocities.com/neonlightfantasy/image/stars.gif')",
+            button_primary_background_fill="linear-gradient(90deg, *primary_500, *secondary_500)",
+            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_500, *primary_500)",
+            button_primary_text_color="*neutral_50",
+        )

+
+geocities90s = Geocities90s()
+
+css = """
+body {
+    margin: 0;
+    padding: 0;
+    color: #00ff00;
+    font-family: 'Comic Sans MS', cursive;
+}
+.gradio-container {
+    background-image: url('https://i.ytimg.com/vi/5WapcCXEcXA/maxresdefault.jpg');
+    background-repeat: repeat;
+    background-size: 300px 300px;
+}
+h1 {
+    text-align: center;
+    color: #ff00ff;
+    text-shadow: 2px 2px #000000;
+    font-size: 36px;
+}
+.yellow-text {
+    color: #ffff00;
+    text-shadow: 2px 2px #000000;
+    font-weight: bold;
+}
+"""
+
+demo = gr.Interface(
+    fn=search_by_text_and_return_images,
+    inputs=[
+        gr.Textbox(
+            label="Enter your cosmic query",
+            placeholder="e.g., alien abduction, crop circles",
+        ),
+        gr.Slider(
+            minimum=1,
+            maximum=10,
+            step=1,
+            label="Number of classified documents",
+            value=5,
+        ),
+    ],
+    outputs=gr.Gallery(label="Declassified UFO Sightings", elem_id="gallery"),
+    title="🛸 Top Secret UFO Document Search 🛸",
+    description="<marquee direction='left' scrollamount='5' class='yellow-text'>Uncover the truth that's out there! The government doesn't want you to know!</marquee>",
+    css=css,
+    allow_flagging="never",
+    theme=geocities90s,
+)
+
+if __name__ == "__main__":
+    demo.launch()
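
Note: this commit removes the in-app setup_qdrant() indexing step, so the new app assumes the "ufo" collection already exists on the remote Qdrant instance named by QDRANT_URL/QDRANT_API_KEY. A minimal sketch of how such a multivector collection might be created offline, following the pattern of the removed code (the endpoint, API key, and the 128-dim per-token vector size are assumptions, not taken from the commit):

# Hypothetical offline setup mirroring the removed setup_qdrant();
# URL, API key, and vector size are placeholders.
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="https://example.qdrant.cloud", api_key="...")
client.create_collection(
    collection_name="ufo",
    vectors_config=models.VectorParams(
        size=128,  # assumed per-token embedding dimension for ColPali
        distance=models.Distance.COSINE,
        multivector_config=models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM  # late-interaction (MaxSim) scoring
        ),
    ),
)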
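
The query path likewise moves from the legacy qdrant_client.search(query_vector=...) to query_points, which accepts a multivector (one vector per query token) when the collection was created with a MultiVectorConfig, and returns its hits under .points. A toy usage sketch, with placeholder vectors standing in for real ColPali query embeddings:

# Toy multivector query; client, vectors, and sizes are placeholders.
from qdrant_client import QdrantClient

client = QdrantClient(url="https://example.qdrant.cloud", api_key="...")
multivector = [[0.1] * 128, [0.2] * 128]  # one (assumed) 128-dim vector per query token
response = client.query_points(
    collection_name="ufo",
    query=multivector,
    limit=5,
)
for point in response.points:
    print(point.id, point.score)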
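
Finally, search_by_text_and_return_images relies on Qdrant point ids having been assigned as dataset row indices at indexing time, so hits map straight back to images via datasets.Dataset.select, which preserves the order of the indices it is given. A toy illustration with made-up ids:

# Made-up point ids standing in for query_points results, in score order.
from datasets import load_dataset

ds = load_dataset("davanstrien/ufo-ColPali", split="train")
row_ids = [3, 1, 4]  # hypothetical ids returned by Qdrant
images = list(ds.select(row_ids)["image"])  # PIL images, ready for gr.Gallery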