"""Gradio demo: semantic search over Simple Wikipedia via retrieval + reranking.

At import time this script loads the dataset, the bi-encoder and cross-encoder
models, connects to Qdrant (credentials come from the ``QDRANT_URL`` and
``QDRANT_API_KEY`` environment variables), validates the target collection,
builds the Gradio UI and launches it.
"""
import os

import gradio as gr
import pandas as pd
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder, SentenceTransformer

from download_dataset import get_dataset
from search_helpers import (
    extract_sentence_and_partition,
    fetch_top_article_with_passage_highlighted,
    rerank_hits,
    retrieve_top_k,
)

SIMPLE_WIKI_PATH = 'simplewiki-2020-11-01.jsonl.gz'
COLLECTION_NAME = 'simplewiki'
RETRIEVAL_TOP_K = 40
DISPLAY_TOP_K = 10
# Number of vectors the indexed collection is expected to hold.
EXPECTED_VECTORS_COUNT = 508000

# Column names shared by the results DataFrame and the gr.Dataframe headers.
RESULT_COLUMNS = (
    "Retrieval Order",
    "Reranking Order",
    "Title",
    "Answer",
    "Article Text",
)

# Markdown templates are kept flush-left so rendering does not depend on any
# dedenting; each paragraph / list item stays on a single line.
INTRO_MARKDOWN = """
# Simple Wikipedia Semantic Search 🔍 Through Retrieval and Reranking

By inputting queries or questions, this space leverages machine learning to surface the most relevant Simple Wikipedia passages and articles, providing most relevant answers out of **{vectors_count}** passages indexed on Qdrant cloud using binary quantization.
"""

FEATURES_MARKDOWN = """
## Features

1. Encode all passages from Simple Wikipedia dataset into embeddings using a pretrained bi-encoder [`multi-qa-MiniLM-L6-cos-v1`](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) from Sentence Transformers
2. Index the embeddings on `Qdrant` cloud using binary quantization for efficient retrieval, resulting in {vectors_count} vector embeddings for encoded passages
3. The user enters a search query like a sentence or a question
4. Encoding the user search query using the bi-encoder model
5. Retrieve the {retrieval_top_k} most relevant passages to the input query by sifting through the indexed embeddings in the Qdrant collection and by leveraging binary quantization to boost retrieval speed
6. Rerank search results using a cross-encoder [`ms-marco-MiniLM-L-12-v2`](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) to prioritize the most contextually relevant passages
7. Show the top article with the answer highlighted in green, the top {display_top_k} reranked answers in a DataFrame view, and the processing time required for both retrieval and reranking
"""

dataset = get_dataset(SIMPLE_WIKI_PATH)
passages = dataset['passages']
articles = dataset['articles']

encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

qdrant = QdrantClient(
    url=os.environ['QDRANT_URL'],
    api_key=os.environ['QDRANT_API_KEY'],
)

# Fetch the collection listing once and reuse it for both logging and the
# existence check (each call is a round trip to the Qdrant server).
collections_response = qdrant.get_collections()
print(collections_response)
collections_names = [
    collection.name for collection in collections_response.collections
]

# Explicit raises instead of `assert`: assertions are stripped under
# `python -O`, which would silently skip these startup checks.
if COLLECTION_NAME not in collections_names:
    raise RuntimeError(
        f"Qdrant collection {COLLECTION_NAME!r} not found; "
        f"available collections: {collections_names}"
    )

# Read the vector count once; it is reused in the UI text below.
vectors_count = qdrant.get_collection(COLLECTION_NAME).vectors_count
if vectors_count != EXPECTED_VECTORS_COUNT:
    raise RuntimeError(
        f"Qdrant collection {COLLECTION_NAME!r} holds {vectors_count} vectors, "
        f"expected {EXPECTED_VECTORS_COUNT}"
    )


def _format_seconds(duration):
    """Render a duration as a string rounded to 3 decimals with an ' s' suffix."""
    return f"{round(duration, 3)} s"


def process_query(query):
    """Run retrieval and reranking for ``query`` and build the UI outputs.

    Args:
        query: The search text entered by the user.

    Returns:
        A 3-tuple matching the outputs wired to the search button:
        the value produced by ``fetch_top_article_with_passage_highlighted``
        for the reranked hits, a DataFrame with one row per displayed hit
        (columns as in ``RESULT_COLUMNS``), and a dict of formatted
        embedding / retrieval / reranking timings.
        For a blank query, empty values are returned without searching.
    """
    # A blank query carries nothing to search for; skip the model and
    # database calls and clear the outputs instead.
    if not query or not query.strip():
        return [], pd.DataFrame(columns=list(RESULT_COLUMNS)), {}

    original_hits, retrieval_time, embedding_time = retrieve_top_k(
        query,
        RETRIEVAL_TOP_K,
        vec_db=qdrant,
        encoder=encoder,
        collection_name=COLLECTION_NAME,
    )
    reranked_hits, reranking_time = rerank_hits(
        query,
        original_hits,
        cross_encoder=cross_encoder,
        articles=articles,
    )
    # Only the best DISPLAY_TOP_K reranked hits are shown.
    reranked_hits = reranked_hits[:DISPLAY_TOP_K]

    df = pd.DataFrame(
        {
            "Retrieval Order": [hit['retrieval_order'] for hit in reranked_hits],
            "Reranking Order": [hit['reranked_order'] for hit in reranked_hits],
            "Title": [hit['title'] for hit in reranked_hits],
            "Answer": [hit['passage'] for hit in reranked_hits],
            "Article Text": [
                articles[hit['article_id']]['content'] for hit in reranked_hits
            ],
        }
    )

    timings = {
        "Embedding Time": _format_seconds(embedding_time),
        "Retrieval Time": _format_seconds(retrieval_time),
        "Reranking Time": _format_seconds(reranking_time),
    }

    return (
        fetch_top_article_with_passage_highlighted(reranked_hits, articles=articles),
        df,
        timings,
    )


def update(selected_index: gr.SelectData, df):
    """Re-highlight the article for the row selected in the results table.

    Args:
        selected_index: Gradio selection event; ``selected_index.index[0]`` is
            used as the positional row index into ``df``.
            NOTE(review): presumably Gradio passes ``[row, column]`` here for
            a Dataframe selection - confirm against the installed version.
        df: The current results DataFrame; must have the "Article Text" and
            "Answer" columns.

    Returns:
        The result of ``extract_sentence_and_partition`` for the selected
        row's article text and answer passage.
    """
    row = df.iloc[selected_index.index[0]]
    return extract_sentence_and_partition(row['Article Text'], row['Answer'])


with gr.Blocks() as retrieve_rerank_demo:
    gr.Markdown(INTRO_MARKDOWN.format(vectors_count=vectors_count))

    with gr.Accordion("Click to learn about the retrieval process", open=False):
        gr.Markdown(
            FEATURES_MARKDOWN.format(
                vectors_count=vectors_count,
                retrieval_top_k=RETRIEVAL_TOP_K,
                display_top_k=DISPLAY_TOP_K,
            )
        )

    input_question = gr.Textbox(
        label="Query for Simple Wikipedia articles",
        placeholder="Enter a query to search for relevant texts from Simple Wikipedia",
    )

    gr.Examples(
        examples=[
            ["capital of united states"],
            ["pyramids of Egypt"],
            ["number of countries in Africa"],
            ["how many people live in alexandria"],
            ["where is the red sea?"],
        ],
        inputs=[input_question],
    )

    button = gr.Button("Search 🔍")

    with gr.Accordion("Click to read the top article with answer highlighted", open=True):
        highlighted_article_after_rerank = gr.HighlightedText(
            value=[],
            label="Top Article with Answer Highlighted",
            color_map={'relevant passage': 'green'},
        )

    df_output = gr.Dataframe(headers=list(RESULT_COLUMNS))

    runtime_info = gr.Json()

    # Search button: run the full retrieve + rerank pipeline.
    button.click(
        fn=process_query,
        inputs=[
            input_question,
        ],
        outputs=[
            highlighted_article_after_rerank,
            df_output,
            runtime_info,
        ],
    )

    # Selecting a table row swaps the highlighted article to that row's hit.
    df_output.select(
        fn=update,
        inputs=df_output,
        outputs=highlighted_article_after_rerank,
    )

retrieve_rerank_demo.launch(share=True)