yalaa's picture
Update app.py
e33ee91 verified
raw
history blame contribute delete
No virus
5.3 kB
import os
import gradio as gr
import pandas as pd
from download_dataset import get_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from qdrant_client import QdrantClient
from search_helpers import (
retrieve_top_k,
rerank_hits,
fetch_top_article_with_passage_highlighted,
extract_sentence_and_partition,
)
# --- Configuration -----------------------------------------------------------
SIMPLE_WIKI_PATH = 'simplewiki-2020-11-01.jsonl.gz'
COLLECTION_NAME = 'simplewiki'
# Number of hits requested from the vector database for every query.
RETRIEVAL_TOP_K = 40
# Number of reranked hits kept for display.
DISPLAY_TOP_K = 10
# Number of vectors the Qdrant collection must contain for the app to start.
EXPECTED_VECTORS_COUNT = 508000

# --- Data and models ---------------------------------------------------------
dataset = get_dataset(SIMPLE_WIKI_PATH)
passages = dataset['passages']
articles = dataset['articles']

encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')

# --- Vector database ---------------------------------------------------------
# Credentials come from the environment; a missing variable raises KeyError.
qdrant = QdrantClient(
    url=os.environ['QDRANT_URL'],
    api_key=os.environ['QDRANT_API_KEY'],
)

# Fetch the collection listing once and reuse it for logging and validation.
collections_response = qdrant.get_collections()
print(collections_response)
collections_names = [collection.name for collection in collections_response.collections]

# Explicit checks instead of `assert`: assertions are removed when Python runs
# with -O, which would let the app start against a missing or partial index.
if COLLECTION_NAME not in collections_names:
    raise RuntimeError(
        f"Qdrant collection {COLLECTION_NAME!r} not found; "
        f"available collections: {collections_names}"
    )

vectors_count = qdrant.get_collection(COLLECTION_NAME).vectors_count
if vectors_count != EXPECTED_VECTORS_COUNT:
    raise RuntimeError(
        f"Qdrant collection {COLLECTION_NAME!r} holds {vectors_count} vectors, "
        f"expected {EXPECTED_VECTORS_COUNT}"
    )
def process_query(query):
    """Run the retrieve-then-rerank pipeline for a single query.

    The query is handed to ``retrieve_top_k`` (requesting ``RETRIEVAL_TOP_K``
    hits from the Qdrant collection), the hits are handed to ``rerank_hits``,
    and only the first ``DISPLAY_TOP_K`` reranked hits are kept.

    Returns a 3-tuple of:
      * the result of ``fetch_top_article_with_passage_highlighted`` for the
        kept hits,
      * a ``pandas.DataFrame`` with one row per kept hit,
      * a dict of embedding / retrieval / reranking times formatted as
        ``"<rounded to 3 decimals> s"``.
    """
    candidate_hits, retrieval_time, embedding_time = retrieve_top_k(
        query,
        RETRIEVAL_TOP_K,
        vec_db=qdrant,
        encoder=encoder,
        collection_name=COLLECTION_NAME,
    )
    ranked_hits, reranking_time = rerank_hits(
        query,
        candidate_hits,
        cross_encoder=cross_encoder,
        articles=articles,
    )
    top_hits = ranked_hits[:DISPLAY_TOP_K]

    # Column header -> key read from each hit; order defines the column order.
    hit_fields = (
        ("Retrieval Order", 'retrieval_order'),
        ("Reranking Order", 'reranked_order'),
        ("Title", 'title'),
        ("Answer", 'passage'),
    )
    table = {
        header: [hit[key] for hit in top_hits]
        for header, key in hit_fields
    }
    # The full article text is looked up through the hit's article id.
    table["Article Text"] = [
        articles[hit['article_id']]['content'] for hit in top_hits
    ]
    df = pd.DataFrame(table)

    def _as_seconds(elapsed):
        return f"{round(elapsed, 3)!s} s"

    highlighted_article = fetch_top_article_with_passage_highlighted(
        top_hits, articles=articles
    )
    timings = {
        "Embedding Time": _as_seconds(embedding_time),
        "Retrieval Time": _as_seconds(retrieval_time),
        "Reranking Time": _as_seconds(reranking_time),
    }
    return highlighted_article, df, timings
def update(selected_index: gr.SelectData, df):
    """Re-highlight the article belonging to the row selected in the table.

    ``selected_index.index[0]`` is used as the positional row index into
    ``df``; that row's 'Article Text' and 'Answer' cells are passed to
    ``extract_sentence_and_partition`` and its result is returned.
    """
    row_position = selected_index.index[0]
    chosen_row = df.iloc[row_position]
    article_text = chosen_row['Article Text']
    answer_passage = chosen_row['Answer']
    return extract_sentence_and_partition(article_text, answer_passage)
# Ask Qdrant for the collection size once and reuse it in both Markdown
# panels instead of issuing one request per panel.
indexed_vectors_count = qdrant.get_collection(COLLECTION_NAME).vectors_count

# UI layout: intro text, collapsible pipeline description, query box with
# examples, search button, highlighted top article, results table and timings.
with gr.Blocks() as retrieve_rerank_demo:
    gr.Markdown(
        """
# Simple Wikipedia Semantic Search 🔍 Through Retrieval and Reranking
By entering queries or questions, this space leverages machine learning to surface the most relevant Simple Wikipedia passages and articles, providing the most relevant answers out of **{}** passages indexed on Qdrant cloud using binary quantization.
""".format(indexed_vectors_count)
    )

    with gr.Accordion("Click to learn about the retrieval process", open=False):
        # The top-k figures are interpolated from RETRIEVAL_TOP_K and
        # DISPLAY_TOP_K so the description cannot drift from the pipeline.
        gr.Markdown(
            """
## Features
1. Encode all passages from Simple Wikipedia dataset into embeddings using a pretrained bi-encoder [`multi-qa-MiniLM-L6-cos-v1`](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) from Sentence Transformers
2. Index the embeddings on `Qdrant` cloud using binary quantization for efficient retrieval, resulting in {vectors_count} vector embeddings for encoded passages
3. The user enters a search query like a sentence or a question
4. Encoding the user search query using the bi-encoder model
5. Retrieve the {retrieval_top_k} most relevant passages to the input query by sifting through the indexed embeddings in the Qdrant collection and by leveraging binary quantization to boost retrieval speed
6. Rerank search results using a cross-encoder [`ms-marco-MiniLM-L-12-v2`](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) to prioritize the most contextually relevant passages
7. Show the top article with the answer highlighted in green, the top {display_top_k} reranked answers in a DataFrame view, and the processing time required for both retrieval and reranking
""".format(
                vectors_count=indexed_vectors_count,
                retrieval_top_k=RETRIEVAL_TOP_K,
                display_top_k=DISPLAY_TOP_K,
            )
        )

    input_question = gr.Textbox(
        label="Query for Simple Wikipedia articles",
        placeholder="Enter a query to search for relevant texts from Simple Wikipedia",
    )

    gr.Examples(
        examples=[
            ["capital of united states"],
            ["pyramids of Egypt"],
            ["number of countries in Africa"],
            ["how many people live in alexandria"],
            ["where is the red sea?"],
        ],
        inputs=[input_question],
    )

    button = gr.Button("Search 🔍")

    with gr.Accordion("Click to read the top article with answer highlighted", open=True):
        highlighted_article_after_rerank = gr.HighlightedText(
            value=[],
            label="Top Article with Answer Highlighted",
            color_map={'relevant passage': 'green'},
        )

    # Headers mirror the DataFrame columns built by process_query.
    df_output = gr.Dataframe(
        headers=[
            "Retrieval Order",
            "Reranking Order",
            "Title",
            "Answer",
            "Article Text",
        ]
    )

    runtime_info = gr.Json()

    # Clicking the button runs the full pipeline and fills all three outputs.
    button.click(
        fn=process_query,
        inputs=[
            input_question,
        ],
        outputs=[
            highlighted_article_after_rerank,
            df_output,
            runtime_info,
        ],
    )

    # Selecting a table cell re-renders the highlighted article for that row.
    df_output.select(
        fn=update,
        inputs=df_output,
        outputs=highlighted_article_after_rerank,
    )

retrieve_rerank_demo.launch(share=True)