zotero-fasthtml / app.py
rbiswasfc's picture
change py version
8fb3197
raw
history blame
7.74 kB
import json
import os
from datetime import datetime
from typing import ClassVar
# import dotenv
import lancedb
import srsly
from fasthtml.common import * # noqa
from fasthtml_hf import setup_hf_backup
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise
# dotenv.load_dotenv()
# download the zotero index (~1200 papers as of July 24, currently hosted on HF) ----
def download_data():
    """Fetch the Zotero dataset snapshot from the Hugging Face Hub into ./data.

    The access token is read from the ``HF_TOKEN`` environment variable; a
    missing variable raises ``KeyError`` before any download is attempted.
    """
    hf_token = os.environ["HF_TOKEN"]
    snapshot_download(
        repo_id="rbiswasfc/zotero_db",
        repo_type="dataset",
        local_dir="./data",
        token=hf_token,
    )
    print("Data downloaded!")
# Download the dataset at import time, but only when the local LanceDB
# directory is absent; an existing directory is never refreshed.
if not os.path.exists(
    "./data/.lancedb_zotero_v0"
):  # TODO: implement a better check / refresh mechanism
    download_data()
# cohere embedding utils ----
@register("coherev3")
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    """LanceDB text-embedding function backed by the Cohere embed endpoint.

    Registered in the LanceDB embedding registry under ``"coherev3"``. One
    Cohere client is created lazily on first use and shared by all instances
    through the ``client`` class attribute.
    """

    # Cohere model identifier passed to ``client.embed``.
    name: str = "embed-english-v3.0"
    # Shared Cohere client; stays ``None`` until ``_init_client`` runs.
    client: ClassVar = None

    def ndims(self):
        """Return the dimensionality of the vectors this function produces.

        ``embed-english-v3.0`` emits 1024-dimensional embeddings, which is
        also the width of the ``Vector(1024)`` column declared in this file.
        """
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, in input order.
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )
        return list(rs.embeddings)

    def _init_client(self):
        """Create the shared Cohere client on first use.

        Reads the API key from ``COHERE_API_KEY``; a missing variable raises
        ``KeyError``. Later calls are no-ops.
        """
        if CohereEmbeddingFunction_2.client is not None:
            return
        # Import lazily so the module loads even when cohere is not needed.
        cohere = attempt_import_or_raise("cohere")
        CohereEmbeddingFunction_2.client = cohere.Client(os.environ["COHERE_API_KEY"])
# Module-wide embedding function instance; supplies the source/vector field
# markers used by the LanceDB model below.
COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()
# LanceDB model ----
class ArxivModel(LanceModel):
    """LanceDB row schema for the indexed paper text."""

    # Text that COHERE_EMBEDDER treats as the embedding source.
    text: str = COHERE_EMBEDDER.SourceField()
    # 1024-wide embedding column tied to COHERE_EMBEDDER.
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    # NOTE(review): presumably a section/chunk title as opposed to the
    # paper's own title below -- confirm against the indexing pipeline.
    title: str
    # Paper title; query_db pairs it with arxiv_id when building results.
    paper_title: str
    # NOTE(review): meaning of content_type is not visible here -- confirm.
    content_type: str
    # arXiv identifier; used as the grouping key and to build result URLs.
    arxiv_id: str
# Application version string.
VERSION = "0.0.0a"
# Connection to the local LanceDB directory fetched by download_data().
DB = lancedb.connect("./data/.lancedb_zotero_v0")
# Mapping read from JSON; _format_results looks abstracts up by arXiv id.
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
# Reranker instances keyed by the name accepted by query_db(reranker=...).
# Both are constructed eagerly at import time.
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
# Table that query_db searches.
TBL = DB.open_table("arxiv_zotero_v0")
# format results ----
def _clean_abstract(abstract, paper_title):
    """Strip leading boilerplate (heading, title, short preamble) from an abstract."""
    # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
    # Keep only the text after the last "Abstract" heading; a string without
    # the heading passes through unchanged.
    abstract = abstract.split("Abstract\n\n")[-1]
    # Drop everything up to the last repetition of the title. An empty or
    # non-string title is skipped: "" is contained in every string and
    # str.split rejects an empty separator.
    if paper_title and isinstance(paper_title, str):
        abstract = abstract.split(paper_title)[-1]
    if abstract.startswith("\n"):
        abstract = abstract[1:]
    # A paragraph break within the first 20 characters marks a short leading
    # fragment; discard it and keep the rest.
    if "\n\n" in abstract[:20]:
        abstract = abstract.split("\n\n", 1)[1]
    return abstract


def _format_results(arxiv_refs):
    """Turn an ``{arxiv_id: paper_title}`` mapping into display records.

    Parameters
    ----------
    arxiv_refs: dict
        arXiv ids mapped to paper titles; iteration order is preserved.

    Returns
    -------
    list[dict]
        One dict per paper with ``"title"``, ``"url"`` (the arXiv abs page)
        and ``"abstract"`` (cleaned text from ``ID_TO_ABSTRACT``, or ``""``
        when the id is unknown or its stored abstract is empty/null).
    """
    return [
        {
            "title": paper_title,
            "url": f"https://arxiv.org/abs/{arx_id}",
            # `or ""` also covers a null abstract stored in the JSON.
            "abstract": _clean_abstract(
                ID_TO_ABSTRACT.get(arx_id) or "", paper_title
            ),
        }
        for arx_id, paper_title in arxiv_refs.items()
    ]
# Search logic ----
def query_db(query: str, k: int = 10, reranker: str = "cohere", top_n: int = 3):
    """Search the index and return the best-matching papers.

    Runs a hybrid search on ``TBL`` limited to ``k`` rows, optionally reranks
    them, sums ``_relevance_score`` per ``arxiv_id`` and keeps the ``top_n``
    papers with the highest totals.

    Parameters
    ----------
    query: str
        Free-text search query.
    k: int
        Maximum number of rows fetched from the table.
    reranker: str
        Key into ``RERANKERS``; pass ``None`` to skip explicit reranking.
        An unknown key raises ``KeyError``.
    top_n: int
        Number of papers to return (default 3).

    Returns
    -------
    list[dict]
        Records produced by ``_format_results``, ordered from highest to
        lowest aggregated relevance.
    """
    search = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is not None:
        search = search.rerank(reranker=RERANKERS[reranker])
    hits = search.to_pandas()

    # Several rows can belong to the same paper; rank papers by total score.
    paper_scores = (
        hits.groupby("arxiv_id")["_relevance_score"]
        .sum()
        .sort_values(ascending=False)
        .head(top_n)
    )

    # Title lookup per paper (the last row wins if titles ever disagree).
    titles = dict(zip(hits["arxiv_id"], hits["paper_title"]))

    # Build the mapping in score order so the ranking computed above is the
    # order the caller receives.
    top_papers = {arxiv_id: titles[arxiv_id] for arxiv_id in paper_scores.index}
    return _format_results(top_papers)
###########################################################################
# FastHTML app -----
###########################################################################
# Page-wide stylesheet: dark colour scheme, centred body capped at 1200px,
# full-width query box and submit button, and boxed log entries whose <pre>
# content wraps instead of overflowing.
style = Style(
    """
:root {
color-scheme: dark;
}
body {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
line-height: 1.6;
}
#query {
width: 100%;
margin-bottom: 1rem;
}
#search-form button {
width: 100%;
}
#search-results, #log-entries {
margin-top: 2rem;
}
.log-entry {
border: 1px solid #ccc;
padding: 10px;
margin-bottom: 10px;
}
.log-entry pre {
white-space: pre-wrap;
word-wrap: break-word;
}
"""
)
# get the fast app and route
# The custom stylesheet is handed to the app through `hdrs`.
# NOTE(review): live=True presumably turns on live reload -- confirm it is
# wanted in the deployed app.
app, rt = fast_app(live=True, hdrs=(style,))
# Initialize a database to store search logs --
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs
# Create the table only when it does not exist yet.
if search_logs not in db.t:
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
# Row type derived from the table; insert_log_entry builds instances of it.
SearchLog = search_logs.dataclass()
def insert_log_entry(log_entry):
    """Store one search event as a row of the ``search_logs`` table.

    ``log_entry`` is a mapping with ``"timestamp"`` (anything exposing
    ``isoformat()``), ``"query"`` and ``"results"`` (JSON-serialisable).
    The value returned by ``search_logs.insert`` is passed back unchanged.
    """
    row = SearchLog(
        timestamp=log_entry["timestamp"].isoformat(),
        query=log_entry["query"],
        results=json.dumps(log_entry["results"]),
    )
    return search_logs.insert(row)
@rt("/")
async def get():
    """Render the landing page: query form, results placeholder and logs link."""
    # The form posts to /search and swaps the response into #search-results.
    query_box = Textarea(id="query", name="query", placeholder="Enter your query...")
    submit_button = Button("Submit", type="submit")
    search_form = Form(
        query_box,
        submit_button,
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
    )

    # Empty container that the search response is swapped into.
    results_area = Div(Div(id="search-results", cls="results-container"))
    logs_link = A("View Logs", href="/logs", cls="view-logs-link")

    page_body = Div(search_form, results_area, logs_link, cls="container")
    return Titled("Zotero Search", page_body)
def SearchResult(result):
    """Build the card for one search hit.

    ``result`` must provide ``"title"``, ``"url"`` and ``"abstract"``; the
    title and the footer link both open the URL in a new tab.
    """
    title, url = result["title"], result["url"]
    heading = H4(A(title, href=url, target="_blank"))
    summary = P(result["abstract"])
    read_more = A("Read more →", href=url, target="_blank")
    return Card(heading, summary, footer=read_more)
def log_query_and_results(query, results):
    """Record a query and the title/URL of each result in the search log.

    Only ``"title"`` and ``"url"`` of every result are kept; the entry is
    stamped with the current local time.
    """
    stamped_at = datetime.now()
    hit_summaries = [{"title": hit["title"], "url": hit["url"]} for hit in results]
    insert_log_entry(
        {"timestamp": stamped_at, "query": query, "results": hit_summaries}
    )
@rt("/search")
async def post(query: str):
    """Handle a submitted query: search, log it, and return the result cards.

    The returned ``Div`` carries the id ``search-results``.
    """
    hits = query_db(query)
    log_query_and_results(query, hits)
    cards = [SearchResult(hit) for hit in hits]
    return Div(*cards, id="search-results")
def LogEntry(entry):
    """Render one stored search log row as a ``.log-entry`` block.

    Reads ``entry.query``, ``entry.timestamp`` and ``entry.results``; the
    results are shown as stored, inside a ``<pre>`` element.
    """
    parts = (
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
    )
    return Div(*parts, cls="log-entry")
@rt("/logs")
async def get():
    """Render the log page listing up to 50 search log rows, highest id first."""
    recent_logs = search_logs(order_by="-id", limit=50)
    rendered = [LogEntry(row) for row in recent_logs]

    page_body = Div(
        H2("Recent Search Logs"),
        Div(*rendered, id="log-entries"),
        A("Back to Search", href="/", cls="back-link"),
        cls="container",
    )
    return Titled("Logs", page_body)
# Script entry point: register the Hugging Face backup hook on the app, then
# start the server through FastHTML's run_uv helper.
if __name__ == "__main__":
    # import uvicorn
    # uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
    setup_hf_backup(app)
    run_uv()