import time

import gradio as gr
from datasets import load_dataset
import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import quantize_embeddings
import faiss
from usearch.index import Index

# Load titles and texts
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-2023-11-embed-en-pre-1", split="train"
).select_columns(["title", "text"])

# Load the int8 and binary indices. Int8 is loaded as a view to save memory,
# as we never actually perform search with it.
int8_view = Index.restore("wikipedia_int8_usearch_1m.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_1m.index")

# Load the SentenceTransformer model for embedding the queries
model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(query, top_k: int = 10, rerank_multiplier: int = 4):
    """Retrieve the ``top_k`` best passages for ``query`` via binary search + int8 rescoring.

    The query is embedded, quantized to ubinary, and used to fetch
    ``top_k * rerank_multiplier`` candidates from the binary index. Those
    candidates are rescored with the float32 query embedding against their
    int8 document embeddings, and the ``top_k`` highest-scoring are returned.

    Args:
        query: The query text to embed and search for.
        top_k: Number of results to return.
        rerank_multiplier: Oversampling factor for the binary search stage.

    Returns:
        A tuple ``(df, timings)`` where ``df`` is a DataFrame with the columns
        "Score", "Title" and "Text" (best result first) and ``timings`` maps
        each stage name to a formatted duration string.
    """
    # 1. Embed the query as float32
    start_time = time.time()
    query_embedding = model.encode(query)
    embed_time = time.time() - start_time

    # 2. Quantize the query to ubinary
    start_time = time.time()
    # Both quantize_embeddings and the FAISS binary search operate on a 2D
    # (n_queries, dim) batch, so promote the single query vector to one row.
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.time() - start_time

    # 3. Search the binary index
    start_time = time.time()
    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rerank_multiplier)
    binary_ids = binary_ids[0]
    search_time = time.time() - start_time

    # 4. Load the corresponding int8 embeddings
    start_time = time.time()
    int8_embeddings = int8_view[binary_ids].astype(int)
    load_time = time.time() - start_time

    # 5. Rerank the top_k * rerank_multiplier using the float32 query embedding
    #    and the int8 document embeddings
    start_time = time.time()
    scores = query_embedding @ int8_embeddings.T
    rerank_time = time.time() - start_time

    # 6. Sort the scores and return the top_k
    start_time = time.time()
    # argsort of the negated scores orders candidates best-first, so the
    # top_k results are the *leading* entries of that ordering.
    top_k_indices = (-scores).argsort()[:top_k]
    top_k_scores = scores[top_k_indices]
    # Fetch each dataset row once and read both fields from it.
    rows = [title_text_dataset[idx] for idx in binary_ids[top_k_indices].tolist()]
    top_k_titles = [row["title"] for row in rows]
    top_k_texts = [row["text"] for row in rows]
    df = pd.DataFrame(
        {
            "Score": [round(value, 2) for value in top_k_scores],
            "Title": top_k_titles,
            "Text": top_k_texts,
        }
    )
    sort_time = time.time() - start_time

    # "Total Retrieval Time" sums every stage except embedding the query.
    return df, {
        "Embed Time": f"{embed_time:.4f} s",
        "Quantize Time": f"{quantize_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Load Time": f"{load_time:.4f} s",
        "Rerank Time": f"{rerank_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{quantize_time + search_time + load_time + rerank_time + sort_time:.4f} s",
    }


with gr.Blocks(title="Quantized Retrieval") as demo:
    query = gr.Textbox(label="Query")
    search_button = gr.Button(value="Search")

    with gr.Row():
        with gr.Column(scale=4):
            output = gr.Dataframe(column_widths=["10%", "20%", "80%"], headers=["Score", "Title", "Text"])
        with gr.Column(scale=1):
            json = gr.JSON()

    search_button.click(search, inputs=[query], outputs=[output, json])

demo.queue()
demo.launch(debug=True)