rbiswasfc commited on
Commit
43191f7
1 Parent(s): f9e2dff
Files changed (4) hide show
  1. Dockerfile +11 -0
  2. app.py +276 -0
  3. favicon.ico +0 -0
  4. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY ./app.py /code/
10
+
11
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from datetime import datetime
4
+ from typing import ClassVar
5
+
6
+ # import dotenv
7
+ import lancedb
8
+ import srsly
9
+ from fasthtml.common import * # noqa
10
+ from huggingface_hub import snapshot_download
11
+ from lancedb.embeddings.base import TextEmbeddingFunction
12
+ from lancedb.embeddings.registry import register
13
+ from lancedb.pydantic import LanceModel, Vector
14
+ from lancedb.rerankers import CohereReranker, ColbertReranker
15
+ from lancedb.util import attempt_import_or_raise
16
+
17
+ # dotenv.load_dotenv()
18
+
19
+
20
+ # download the zotero index (~1200 papers as of July 24, currently hosted on HF) ----
21
+ def download_data():
22
+ snapshot_download(
23
+ repo_id="rbiswasfc/zotero_db",
24
+ repo_type="dataset",
25
+ local_dir="./data",
26
+ token=os.environ["HF_TOKEN"],
27
+ )
28
+ print("Data downloaded!")
29
+
30
+
31
+ if not os.path.exists(
32
+ "./data/.lancedb_zotero_v0"
33
+ ): # TODO: implement a better check / refresh mechanism
34
+ download_data()
35
+
36
+
37
+ # cohere embedding utils ----
38
+ @register("coherev3")
39
+ class CohereEmbeddingFunction_2(TextEmbeddingFunction):
40
+ name: str = "embed-english-v3.0"
41
+ client: ClassVar = None
42
+
43
+ def ndims(self):
44
+ return 768
45
+
46
+ def generate_embeddings(self, texts):
47
+ """
48
+ Get the embeddings for the given texts
49
+ Parameters
50
+ ----------
51
+ texts: list[str] or np.ndarray (of str)
52
+ The texts to embed
53
+ """
54
+ # TODO retry, rate limit, token limit
55
+ self._init_client()
56
+ rs = CohereEmbeddingFunction_2.client.embed(
57
+ texts=texts, model=self.name, input_type="search_document"
58
+ )
59
+
60
+ return [emb for emb in rs.embeddings]
61
+
62
+ def _init_client(self):
63
+ cohere = attempt_import_or_raise("cohere")
64
+ if CohereEmbeddingFunction_2.client is None:
65
+ CohereEmbeddingFunction_2.client = cohere.Client(
66
+ os.environ["COHERE_API_KEY"]
67
+ )
68
+
69
+
70
+ COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()
71
+
72
+
73
+ # LanceDB model ----
74
+ class ArxivModel(LanceModel):
75
+ text: str = COHERE_EMBEDDER.SourceField()
76
+ vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
77
+ title: str
78
+ paper_title: str
79
+ content_type: str
80
+ arxiv_id: str
81
+
82
+
83
+ VERSION = "0.0.0a"
84
+ DB = lancedb.connect("./data/.lancedb_zotero_v0")
85
+ ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
86
+ RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
87
+ TBL = DB.open_table("arxiv_zotero_v0")
88
+
89
+
90
+ # format results ----
91
+ def _format_results(arxiv_refs):
92
+ results = []
93
+ for arx_id, paper_title in arxiv_refs.items():
94
+ abstract = ID_TO_ABSTRACT.get(arx_id, "")
95
+ # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
96
+ if "Abstract\n\n" in abstract:
97
+ abstract = abstract.split("Abstract\n\n")[-1]
98
+ if paper_title in abstract:
99
+ abstract = abstract.split(paper_title)[-1]
100
+ if abstract.startswith("\n"):
101
+ abstract = abstract[1:]
102
+ if "\n\n" in abstract[:20]:
103
+ abstract = "\n\n".join(abstract.split("\n\n")[1:])
104
+ result = {
105
+ "title": paper_title,
106
+ "url": f"https://arxiv.org/abs/{arx_id}",
107
+ "abstract": abstract,
108
+ }
109
+ results.append(result)
110
+
111
+ return results
112
+
113
+
114
+ # Search logic ----
115
+ def query_db(query: str, k: int = 10, reranker: str = "cohere"):
116
+ raw_results = TBL.search(query, query_type="hybrid").limit(k)
117
+ if reranker is not None:
118
+ ranked_results = raw_results.rerank(reranker=RERANKERS[reranker])
119
+ else:
120
+ ranked_results = raw_results
121
+
122
+ ranked_results = ranked_results.to_pandas()
123
+ top_results = ranked_results.groupby("arxiv_id").agg({"_relevance_score": "sum"})
124
+ top_results = top_results.sort_values(by="_relevance_score", ascending=False).head(
125
+ 3
126
+ )
127
+ top_results_dict = {
128
+ row["arxiv_id"]: row["paper_title"]
129
+ for index, row in ranked_results.iterrows()
130
+ if row["arxiv_id"] in top_results.index
131
+ }
132
+
133
+ final_results = _format_results(top_results_dict)
134
+ return final_results
135
+
136
+
137
+ ###########################################################################
138
+ # FastHTML app -----
139
+ ###########################################################################
140
+
141
+ style = Style(
142
+ """
143
+ :root {
144
+ color-scheme: dark;
145
+ }
146
+ body {
147
+ max-width: 1200px;
148
+ margin: 0 auto;
149
+ padding: 20px;
150
+ line-height: 1.6;
151
+ }
152
+ #query {
153
+ width: 100%;
154
+ margin-bottom: 1rem;
155
+ }
156
+ #search-form button {
157
+ width: 100%;
158
+ }
159
+ #search-results, #log-entries {
160
+ margin-top: 2rem;
161
+ }
162
+ .log-entry {
163
+ border: 1px solid #ccc;
164
+ padding: 10px;
165
+ margin-bottom: 10px;
166
+ }
167
+ .log-entry pre {
168
+ white-space: pre-wrap;
169
+ word-wrap: break-word;
170
+ }
171
+ """
172
+ )
173
+
174
+ # get the fast app and route
175
+ app, rt = fast_app(live=True, hdrs=(style,))
176
+
177
+ # Initialize a database to store search logs --
178
+ db = database("data/search_logs.db")
179
+ search_logs = db.t.search_logs
180
+ if search_logs not in db.t:
181
+ search_logs.create(
182
+ id=int,
183
+ timestamp=str,
184
+ query=str,
185
+ results=str,
186
+ pk="id",
187
+ )
188
+ SearchLog = search_logs.dataclass()
189
+
190
+
191
+ def insert_log_entry(log_entry):
192
+ "Insert a log entry into the database"
193
+ return search_logs.insert(
194
+ SearchLog(
195
+ timestamp=log_entry["timestamp"].isoformat(),
196
+ query=log_entry["query"],
197
+ results=json.dumps(log_entry["results"]),
198
+ )
199
+ )
200
+
201
+
202
+ @rt("/")
203
+ async def get():
204
+ query_form = Form(
205
+ Textarea(id="query", name="query", placeholder="Enter your query..."),
206
+ Button("Submit", type="submit"),
207
+ id="search-form",
208
+ hx_post="/search",
209
+ hx_target="#search-results",
210
+ )
211
+
212
+ # results_div = Div(H2("Search Results"), Div(id="search-results", cls="results-container"))
213
+ results_div = Div(Div(id="search-results", cls="results-container"))
214
+
215
+ view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")
216
+
217
+ return Titled(
218
+ "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
219
+ )
220
+
221
+
222
+ def SearchResult(result):
223
+ "Custom component for displaying a search result"
224
+ return Card(
225
+ H4(A(result["title"], href=result["url"], target="_blank")),
226
+ P(result["abstract"]),
227
+ footer=A("Read more →", href=result["url"], target="_blank"),
228
+ )
229
+
230
+
231
+ def log_query_and_results(query, results):
232
+ log_entry = {
233
+ "timestamp": datetime.now(),
234
+ "query": query,
235
+ "results": [{"title": r["title"], "url": r["url"]} for r in results],
236
+ }
237
+ insert_log_entry(log_entry)
238
+
239
+
240
+ @rt("/search")
241
+ async def post(query: str):
242
+ results = query_db(query)
243
+ log_query_and_results(query, results)
244
+
245
+ return Div(*[SearchResult(r) for r in results], id="search-results")
246
+
247
+
248
+ def LogEntry(entry):
249
+ return Div(
250
+ H4(f"Query: {entry.query}"),
251
+ P(f"Timestamp: {entry.timestamp}"),
252
+ H5("Results:"),
253
+ Pre(entry.results),
254
+ cls="log-entry",
255
+ )
256
+
257
+
258
+ @rt("/logs")
259
+ async def get():
260
+ logs = search_logs(order_by="-id", limit=50) # Get the latest 50 logs
261
+ log_entries = [LogEntry(log) for log in logs]
262
+ return Titled(
263
+ "Logs",
264
+ Div(
265
+ H2("Recent Search Logs"),
266
+ Div(*log_entries, id="log-entries"),
267
+ A("Back to Search", href="/", cls="back-link"),
268
+ cls="container",
269
+ ),
270
+ )
271
+
272
+
273
+ if __name__ == "__main__":
274
+ import uvicorn
275
+
276
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
favicon.ico ADDED
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python-fasthtml
2
+ uvicorn>=0.29
3
+ lancedb
4
+ srsly
5
+ cohere
6
+ python-dotenv
7
+ tantivy
8
+ beautifulsoup4
9
+ retry
10
+ transformers
11
+ torch