# NOTE: the three lines that originally preceded this file ("Spaces:" /
# "Runtime error" / "Runtime error") were Hugging Face Spaces page-header
# artifacts captured along with the source — not part of the application code.
import os

import gradio as gr
import pandas as pd
from qdrant_client import QdrantClient
from sentence_transformers import CrossEncoder, SentenceTransformer

from download_dataset import get_dataset
from search_helpers import (
    retrieve_top_k,
    rerank_hits,
    fetch_top_article_with_passage_highlighted,
    extract_sentence_and_partition,
)

# Local Simple Wikipedia dump and the Qdrant collection it was indexed into.
SIMPLE_WIKI_PATH = 'simplewiki-2020-11-01.jsonl.gz'
COLLECTION_NAME = 'simplewiki'
# How many candidates vector search returns vs. how many survive reranking.
RETRIEVAL_TOP_K = 40
DISPLAY_TOP_K = 10
# Size the indexed collection is expected to have; sanity-checked at startup.
EXPECTED_VECTORS_COUNT = 508000

dataset = get_dataset(SIMPLE_WIKI_PATH)
passages = dataset['passages']
articles = dataset['articles']

# Bi-encoder embeds the query; cross-encoder reranks retrieved passages.
encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

qdrant = QdrantClient(
    url=os.environ['QDRANT_URL'],
    api_key=os.environ['QDRANT_API_KEY'],
)
print(qdrant.get_collections())

# Fail fast with explicit errors (asserts are stripped under `python -O`)
# if the expected collection is missing or incompletely indexed.
collections_names = [collection.name for collection in qdrant.get_collections().collections]
if COLLECTION_NAME not in collections_names:
    raise RuntimeError(
        f"Qdrant collection '{COLLECTION_NAME}' not found; available: {collections_names}"
    )
vectors_count = qdrant.get_collection(COLLECTION_NAME).vectors_count
if vectors_count != EXPECTED_VECTORS_COUNT:
    raise RuntimeError(
        f"Expected {EXPECTED_VECTORS_COUNT} vectors in '{COLLECTION_NAME}', found {vectors_count}"
    )
def process_query(query):
    """Run the full retrieve-then-rerank pipeline for a user query.

    Args:
        query: Free-text query string entered in the Gradio textbox.

    Returns:
        A 3-tuple matching the Gradio outputs:
          * the top article with the relevant passage highlighted,
          * a DataFrame of the top DISPLAY_TOP_K reranked hits,
          * a dict of timing strings (embedding / retrieval / reranking).
    """
    original_hits, retrieval_time, embedding_time = retrieve_top_k(
        query,
        RETRIEVAL_TOP_K,
        vec_db=qdrant,
        encoder=encoder,
        collection_name=COLLECTION_NAME,
    )
    reranked_hits, reranking_time = rerank_hits(
        query, original_hits, cross_encoder=cross_encoder, articles=articles
    )
    # Keep only the strongest candidates after cross-encoder reranking.
    reranked_hits = reranked_hits[:DISPLAY_TOP_K]
    df = pd.DataFrame(
        {
            "Retrieval Order": [hit['retrieval_order'] for hit in reranked_hits],
            "Reranking Order": [hit['reranked_order'] for hit in reranked_hits],
            "Title": [hit['title'] for hit in reranked_hits],
            "Answer": [hit['passage'] for hit in reranked_hits],
            "Article Text": [articles[hit['article_id']]['content'] for hit in reranked_hits],
        }
    )
    return (
        fetch_top_article_with_passage_highlighted(reranked_hits, articles=articles),
        df,
        {
            # f-strings instead of str(round(...)) + " s" concatenation;
            # round() keeps the rendered values identical to the original.
            "Embedding Time": f"{round(embedding_time, 3)} s",
            "Retrieval Time": f"{round(retrieval_time, 3)} s",
            "Reranking Time": f"{round(reranking_time, 3)} s",
        },
    )
def update(selected_index: gr.SelectData, df):
    """Re-highlight the article for whichever results-table row was clicked."""
    clicked_row = selected_index.index[0]
    row = df.iloc[clicked_row]
    article_text = row['Article Text']
    answer_passage = row['Answer']
    return extract_sentence_and_partition(article_text, answer_passage)
# ---------------------------------------------------------------------------
# Gradio UI: query box -> retrieve + rerank -> highlighted top article,
# results table, and timing info. Clicking a table row re-highlights the
# corresponding article. (Fixes user-facing typos: "inputting", "retrieval",
# "a question", "prioritize".)
# ---------------------------------------------------------------------------
with gr.Blocks() as retrieve_rerank_demo:
    gr.Markdown(
        """
        # Simple Wikipedia Semantic Search π Through Retrieval and Reranking
        By inputting queries or questions, this space leverages machine learning to surface the most relevant Simple Wikipedia passages and articles, providing most relevant answers out of **{}** passages indexed on Qdrant cloud using binary quantization.
        """.format(qdrant.get_collection(COLLECTION_NAME).vectors_count)
    )
    with gr.Accordion("Click to learn about the retrieval process", open=False):
        gr.Markdown(
            """
            ## Features
            1. Encode all passages from Simple Wikipedia dataset into embeddings using a pretrained bi-encoder [`multi-qa-MiniLM-L6-cos-v1`](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) from Sentence Transformers
            2. Index the embeddings on `Qdrant` cloud using binary quantization for efficient retrieval, resulting in {} vector embeddings for encoded passages
            3. The user enters a search query like a sentence or a question
            4. Encoding the user search query using the bi-encoder model
            5. Retrieve the 40 most relevant passages to the input query by sifting through the indexed embeddings in the Qdrant collection and by leveraging binary quantization to boost retrieval speed
            6. Rerank search results using a cross-encoder [`ms-marco-MiniLM-L-12-v2`](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) to prioritize the most contextually relevant passages
            7. Show the top article with the answer highlighted in green, the top 10 reranked answers in a DataFrame view, and the processing time required for both retrieval and reranking
            """.format(qdrant.get_collection(COLLECTION_NAME).vectors_count)
        )
    input_question = gr.Textbox(
        label="Query for Simple Wikipedia articles",
        placeholder="Enter a query to search for relevant texts from Simple Wikipedia",
    )
    gr.Examples(
        examples=[
            ["capital of united states"],
            ["pyramids of Egypt"],
            ["number of countries in Africa"],
            ["how many people live in alexandria"],
            ["where is the red sea?"],
        ],
        inputs=[input_question],
    )
    button = gr.Button("Search π")
    with gr.Accordion("Click to read the top article with answer highlighted", open=True):
        highlighted_article_after_rerank = gr.HighlightedText(
            value=[],
            label="Top Article with Answer Highlighted",
            color_map={'relevant passage': 'green'},
        )
    df_output = gr.Dataframe(
        headers=[
            "Retrieval Order",
            "Reranking Order",
            "Title",
            "Answer",
            "Article Text",
        ]
    )
    runtime_info = gr.Json()
    # Main search action: one query in, three outputs back.
    button.click(
        fn=process_query,
        inputs=[
            input_question,
        ],
        outputs=[
            highlighted_article_after_rerank,
            df_output,
            runtime_info,
        ],
    )
    # Selecting a results-table row swaps in that row's highlighted article.
    df_output.select(
        fn=update,
        inputs=df_output,
        outputs=highlighted_article_after_rerank,
    )

retrieve_rerank_demo.launch(share=True)