# test-bm25 / app.py — author: kwang2049
# Commit d163285: "force gradio version" (file size 1.79 kB)
import os
import subprocess
import sys

# Pin gradio BEFORE it is imported: once `import gradio` has executed, the
# module is cached in this process, so reinstalling the package afterwards
# cannot change the version this script actually uses.
# `sys.executable -m pip` installs into the interpreter running this file
# (a bare `pip` on PATH may belong to a different one). The install is
# best-effort: check=False means a failed install is not fatal here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "gradio==4.44.0"],
    check=False,
)

from typing import Dict, List, Optional, TypedDict  # noqa: E402

import gradio as gr  # noqa: E402

from nlp4web_codebase.ir.data_loaders.sciq import load_sciq  # noqa: E402
from bm25 import BM25Index, BM25Retriever  # noqa: E402
# Directory the BM25 index is written to and then read back from.
BM25_INDEX_DIR = "output/bm25_sciq_index"

# Load SciQ; only `sciq.corpus` (the document collection) is used below.
sciq = load_sciq()

# Index the whole corpus with BM25. k1 and b were tuned on dev wrt. MAP@10.
# NOTE(review): ndocs is hard-coded; presumably it is the SciQ corpus size
# (e.g. for the progress bar) — TODO confirm it matches the corpus length.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=0.8,
    b=0.6,
)

# Persist the freshly built index, then open a retriever over that same
# directory.
bm25_index.save(BM25_INDEX_DIR)
bm25_retriever = BM25Retriever(index_dir=BM25_INDEX_DIR)
class Hit(TypedDict):
    """A single search result: one corpus document plus its retrieval score."""

    cid: str  # collection id of the document (key into the corpus)
    score: float  # score the BM25 retriever assigned to this document
    text: str  # full text of the document
# Result type of the search function: a list of Hit records.
return_type = List[Hit]
# Module-level handle for the Gradio app; the demo is assigned to it below.
demo: Optional[gr.Interface] = None
## YOUR_CODE_STARTS_HERE
# Map each corpus document's collection id to its text, so ids returned by the
# retriever can be resolved back to document text.
cid2doc: Dict[str, str] = {doc.collection_id: doc.text for doc in sciq.corpus}


def search(query: str, topk: int = 10) -> List[Hit]:
    """Retrieve SciQ documents for *query* with BM25, best first.

    Args:
        query: Free-text query (in the demo, the content of the textbox).
        topk: Maximum number of hits to return. Defaults to 10, matching the
            "top-10" promised in the demo description. The retriever itself
            may return fewer hits than this.

    Returns:
        At most ``topk`` ``Hit`` dicts, sorted by descending score. Hits with
        equal scores keep the order in which the retriever returned them.
        An empty, whitespace-only or ``None`` query, or a non-positive
        ``topk``, yields ``[]`` without querying the index.
    """
    if topk <= 0 or not query or not query.strip():
        return []
    ranking: Dict[str, float] = bm25_retriever.retrieve(query)
    # sorted() is stable even with reverse=True, so ties keep retriever order.
    best_first = sorted(ranking.items(), key=lambda item: item[1], reverse=True)
    return [
        Hit(cid=cid, score=score, text=cid2doc[cid])
        for cid, score in best_first[:topk]
    ]
# Gradio UI: a two-line query textbox in, the ranked list of Hit dicts out.
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    # `search` returns a list of dicts. A JSON component renders it as
    # structured data; a "text" output would only show the Python repr.
    outputs=gr.JSON(label="Ranked hits"),
    title="BM25 Retriever Search",
    description="Search using a BM25 retriever on [SciQ](https://huggingface.co/datasets/allenai/sciq) and return top-10 ranked documents with scores.",
)
## YOUR_CODE_ENDS_HERE
# Start the Gradio server. By default launch() blocks the main thread while
# the server runs (see its `prevent_thread_lock` option). To do work after
# launching (e.g. print demo.local_url / demo.local_api_url), start it from a
# background threading.Thread and join that thread at the end instead.
demo.launch()