Spaces:
Runtime error
Runtime error
import gradio as gr | |
import ir_datasets | |
import pandas as pd | |
import numpy as np | |
from autogluon.multimodal import MultiModalPredictor | |
# Shared state written by the embedding callbacks and read by rank_document().
query_embedding = None     # embedding of the last query given to text_embedding_single()
document_embedding = None  # embeddings of the sampled documents held in docs_df
docs_df = None             # sampled FiQA documents, indexed by doc_id

_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


def _new_feature_extractor():
    """Build a MultiModalPredictor configured for text feature extraction."""
    return MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={"model.hf_text.checkpoint_name": _MODEL_NAME},
    )


def text_embedding_batch():
    """Embed a small random sample of the ``beir/fiqa/dev`` documents.

    Stores the sampled documents in the module-level ``docs_df`` and their
    embeddings in ``document_embedding`` so that a later ranking step can
    compare them against a query.

    Returns:
        The embeddings extracted for the ``text`` column of the sample.
    """
    global document_embedding
    global docs_df
    dataset = ir_datasets.load("beir/fiqa/dev")
    # Keep only 0.01% of the corpus so the demo stays fast.
    docs_df = (
        pd.DataFrame(dataset.docs_iter())
        .set_index("doc_id")
        .sample(frac=0.0001)
    )
    embedding = _new_feature_extractor().extract_embedding(docs_df)
    # These are *document* vectors; they were previously stored under
    # ``query_embedding``, which made the ranking compare the wrong operands.
    document_embedding = embedding["text"]
    return document_embedding


def text_embedding_single(query: str):
    """Embed a single query string.

    Stores the result in the module-level ``query_embedding``.

    Args:
        query: The free-text question to embed.

    Returns:
        The embedding extracted for the one-element input ``[query]``.
    """
    global query_embedding
    embedding = _new_feature_extractor().extract_embedding([query])
    # A bare list input is keyed by its column position, hence "0".
    # This is the *query* vector; it was previously stored under
    # ``document_embedding``.
    query_embedding = embedding["0"]
    return query_embedding
def rank_document(top_k: int = 3):
    """Return the texts of the documents most similar to the current query.

    Reads the module-level ``query_embedding``, ``document_embedding`` and
    ``docs_df``. Each row of ``document_embedding`` is scored by cosine
    similarity against the first row of ``query_embedding``; row ``i`` of
    ``document_embedding`` is expected to describe row ``i`` of ``docs_df``.

    NOTE(review): verify that the functions populating these globals store
    the query and document vectors under the matching names.
    NOTE(review): the result is wired to a ``gr.DataFrame`` output; confirm
    Gradio renders a flat list of strings as intended.

    Args:
        top_k: Maximum number of documents to return. Defaults to 3, the
            count announced by the page text ("pick the top 3 most
            relevant"); the previous hard-coded slice returned only 2.

    Returns:
        A list with the ``text`` of up to ``top_k`` documents, most similar
        first.

    Raises:
        ValueError: If the query embedding, the document embeddings or the
            documents have not been generated yet.
    """
    if query_embedding is None or document_embedding is None or docs_df is None:
        raise ValueError(
            "Generate both the query embedding and the document embeddings "
            "before ranking."
        )

    # atleast_2d lets a single 1-D query vector be handled like a 1-row batch.
    queries = np.atleast_2d(np.asarray(query_embedding))
    documents = np.asarray(document_embedding)

    # Normalising both sides turns the dot product into cosine similarity.
    q_norm = queries / np.linalg.norm(queries, axis=-1, keepdims=True)
    d_norm = documents / np.linalg.norm(documents, axis=-1, keepdims=True)
    scores = d_norm.dot(q_norm[0])

    # Negate so argsort yields the highest similarity first.
    best = np.argsort(-scores)[:top_k]
    return [docs_df["text"].iloc[idx] for idx in best]
def main():
    """Build the Gradio demo page, wire its three buttons, and launch it."""
    with gr.Blocks(title="OpenSearch Demo") as demo:
        gr.Markdown("# Semantic Search with Autogluon")
        gr.Markdown("Ask an open question!")

        # Step 1: embed a single free-text query.
        with gr.Row():
            query_box = gr.Textbox(show_label=False)
        with gr.Row():
            embed_query_button = gr.Button("Generate Embedding")
        with gr.Row():
            query_table = gr.DataFrame(label="Embedding", show_label=True)

        # Step 2: embed a sample of a bundled dataset.
        gr.Markdown("You can select one of the sample datasets for document embedding")
        with gr.Row():
            fiqa_button = gr.Button("fiqa")
        with gr.Row():
            corpus_table = gr.DataFrame(
                label="Sample Embeddings", show_label=True, row_count=5
            )

        # Step 3: rank the embedded documents against the query.
        gr.Markdown("Now rank the documents and pick the top 3 most relevant from the dataset")
        with gr.Row():
            rank_button = gr.Button("Rank documents")
        with gr.Row():
            ranking_table = gr.DataFrame(
                label="Top ranked documents", show_label=True, row_count=5
            )

        # Disabled file-upload flow, kept for reference:
        # with gr.Row():
        #     out_batch = gr.File(interactive=True)
        # with gr.Row():
        #     btn_file = gr.Button("Generate Embedding")

        # (button, callback, input component, output component)
        wiring = (
            (embed_query_button, text_embedding_single, query_box, query_table),
            (fiqa_button, text_embedding_batch, None, corpus_table),
            (rank_button, rank_document, None, ranking_table),
        )
        for button, handler, source, target in wiring:
            button.click(fn=handler, inputs=source, outputs=target)
        # btn_file.click(fn=text_embedding_batch, inputs=inp_single, outputs=out_single)

    # NOTE(review): the original indentation was lost; launching after the
    # Blocks context closes is the conventional placement - confirm.
    demo.launch()


# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()