# opensearchspace / app.py
# suzhoum — commit 3b2a25c ("wip")
import gradio as gr
import ir_datasets
import pandas as pd
import numpy as np
from autogluon.multimodal import MultiModalPredictor
# Shared state passed between the Gradio callbacks; every slot starts empty.
# NOTE(review): the two embedding names are swapped relative to what they hold:
#   * query_embedding    <- filled by text_embedding_batch() with the embeddings
#                           of the sampled corpus documents;
#   * document_embedding <- filled by text_embedding_single() with the embedding
#                           of the user's question.
#   * docs_df            <- the sampled corpus documents (indexed by doc_id).
# The names are kept because every callback in this module refers to them.
query_embedding = document_embedding = docs_df = None
def text_embedding_batch():
    """Embed a small random sample of the BEIR FiQA dev corpus.

    Loads the ``beir/fiqa/dev`` documents through ``ir_datasets``, indexes them
    by ``doc_id``, keeps a random 0.01% sample (``frac=0.0001``) and embeds the
    sample with the ``sentence-transformers/all-MiniLM-L6-v2`` checkpoint via
    AutoGluon's ``feature_extraction`` pipeline.

    Side effects:
        Rebinds the module globals ``docs_df`` (the sampled documents) and
        ``query_embedding`` (their embeddings). Despite its name,
        ``query_embedding`` therefore holds the *document* vectors that
        ``rank_document`` scores the user's question against.

    Returns:
        The ``"text"`` entry of the extracted embeddings (the same object that
        is stored in ``query_embedding``), shown in the "Sample Embeddings" table.
    """
    global query_embedding, docs_df

    fiqa = ir_datasets.load("beir/fiqa/dev")
    # The sample is unseeded, so every click embeds a different subset.
    docs_df = (
        pd.DataFrame(fiqa.docs_iter())
        .set_index("doc_id")
        .sample(frac=0.0001)
    )

    # A new predictor (and model load) is created on every call.
    extractor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": "sentence-transformers/all-MiniLM-L6-v2",
        },
    )
    query_embedding = extractor.extract_embedding(docs_df)["text"]
    return query_embedding
def text_embedding_single(query: str):
    """Embed a single free-text question typed by the user.

    Builds an AutoGluon ``feature_extraction`` predictor on the
    ``sentence-transformers/all-MiniLM-L6-v2`` checkpoint (the same one used
    for the corpus sample) and embeds the one-element list ``[query]``.

    Side effects:
        Rebinds the module global ``document_embedding``. Despite its name, it
        holds the *query* vector that ``rank_document`` compares against the
        corpus embeddings.

    Args:
        query: The question entered in the textbox.

    Returns:
        The ``"0"`` entry of the extracted embeddings. This is presumably the
        key AutoGluon uses for an unnamed list input (verify against the
        installed AutoGluon version).
    """
    global document_embedding

    extractor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": "sentence-transformers/all-MiniLM-L6-v2",
        },
    )
    embeddings_by_key = extractor.extract_embedding([query])
    document_embedding = embeddings_by_key["0"]
    return document_embedding
def rank_document(top_k: int = 3):
    """Return the texts of the sampled documents most similar to the question.

    Every document row is scored by cosine similarity against the question
    vector, and the ``top_k`` best texts are returned, best first.

    Reads module globals populated by the embedding callbacks:

    * ``query_embedding``: despite its name, one embedding row per document
      in ``docs_df`` (set by ``text_embedding_batch``);
    * ``document_embedding``: despite its name, the question embedding
      (set by ``text_embedding_single``);
    * ``docs_df``: the sampled documents, with a ``text`` column.

    Args:
        top_k: Maximum number of documents to return. Defaults to 3, matching
            the "top 3" promised in the UI.

    Returns:
        A list of at most ``top_k`` document texts, most similar first.

    Raises:
        ValueError: If the question or corpus embeddings have not been
            generated yet.
    """
    if query_embedding is None or document_embedding is None or docs_df is None:
        raise ValueError(
            "Generate the question embedding and the fiqa document embeddings "
            "before ranking."
        )

    # One row per document; tolerate a single 1-D vector as well.
    doc_vectors = np.atleast_2d(np.asarray(query_embedding, dtype=float))
    # The question embedding is expected as a single row; take that row so the
    # scores below have one entry per *document*. (The previous code dotted the
    # question with only the first document, yielding a single score.)
    question_vector = np.atleast_2d(np.asarray(document_embedding, dtype=float))[0]

    doc_unit = doc_vectors / np.linalg.norm(doc_vectors, axis=-1, keepdims=True)
    question_unit = question_vector / np.linalg.norm(question_vector)
    scores = doc_unit @ question_unit

    # Stable sort so equally scored documents keep their sample order.
    best = np.argsort(-scores, kind="stable")[:top_k]
    texts = docs_df["text"]
    return [texts.iloc[i] for i in best]
def main():
    """Build the three-step semantic-search demo UI and launch it.

    1. "Generate Embedding" embeds the question typed in the textbox.
    2. "fiqa" embeds a random sample of the FiQA corpus.
    3. "Rank documents" ranks the sampled documents against the question.

    Steps 2 and 3 take no UI inputs; the callbacks exchange data through module
    globals, so steps 1 and 2 must both run before step 3.
    """
    with gr.Blocks(title="OpenSearch Demo") as demo:
        # Step 1: the user's question.
        gr.Markdown("# Semantic Search with Autogluon")
        gr.Markdown("Ask an open question!")
        with gr.Row():
            inp_single = gr.Textbox(show_label=False)
        with gr.Row():
            btn_single = gr.Button("Generate Embedding")
        with gr.Row():
            out_single = gr.DataFrame(label="Embedding", show_label=True)

        # Step 2: the document corpus.
        gr.Markdown("You can select one of the sample datasets for document embedding")
        with gr.Row():
            btn_fiqa = gr.Button("fiqa")
        with gr.Row():
            out_batch = gr.DataFrame(label="Sample Embeddings", show_label=True, row_count=5)

        # Step 3: ranking.
        gr.Markdown("Now rank the documents and pick the top 3 most relevant from the dataset")
        with gr.Row():
            btn_rank = gr.Button("Rank documents")
        with gr.Row():
            out_rank = gr.DataFrame(label="Top ranked documents", show_label=True, row_count=5)

        # Event listeners must be attached inside the Blocks context.
        btn_single.click(fn=text_embedding_single, inputs=inp_single, outputs=out_single)
        btn_fiqa.click(fn=text_embedding_batch, inputs=None, outputs=out_batch)
        btn_rank.click(fn=rank_document, inputs=None, outputs=out_rank)

    demo.launch()
if __name__ == "__main__":
    # Launch the UI only when this file is executed as a script, not on import.
    main()