Spaces:
Runtime error
Runtime error
Benjamin Bossan
committed on
Commit
•
a240da9
1
Parent(s):
7281bd6
Initial commit
Browse files- .gitignore +11 -0
- README.md +19 -0
- environment.yml +97 -0
- pyproject.toml +14 -0
- requests.org +57 -0
- requirements-dev.txt +4 -0
- requirements.txt +5 -0
- src/base.py +37 -0
- src/db.py +102 -0
- src/ml.py +143 -0
- src/webservice.py +112 -0
- src/worker.py +117 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
.idea
|
3 |
+
*.log
|
4 |
+
tmp/
|
5 |
+
|
6 |
+
*.py[cod]
|
7 |
+
*.egg
|
8 |
+
build
|
9 |
+
htmlcov
|
10 |
+
|
11 |
+
*.db
|
README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dump your knowledge, let AI refine it
|
2 |
+
|
3 |
+
## Starting
|
4 |
+
|
5 |
+
Install the dependencies (see `requirements.txt` / `environment.yml`), then, in one terminal, start the background worker:
|
6 |
+
|
7 |
+
```sh
|
8 |
+
cd src
|
9 |
+
python worker.py
|
10 |
+
```
|
11 |
+
|
12 |
+
Start the web server:
|
13 |
+
|
14 |
+
```sh
|
15 |
+
cd src
|
16 |
+
uvicorn webservice:app --reload --port 8080
|
17 |
+
```
|
18 |
+
|
19 |
+
For example requests, check `requests.org`.
|
environment.yml
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: gistillery
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- _libgcc_mutex=0.1=main
|
8 |
+
- _openmp_mutex=5.1=1_gnu
|
9 |
+
- blas=1.0=mkl
|
10 |
+
- bzip2=1.0.8=h7b6447c_0
|
11 |
+
- ca-certificates=2023.01.10=h06a4308_0
|
12 |
+
- cuda-cudart=11.7.99=0
|
13 |
+
- cuda-cupti=11.7.101=0
|
14 |
+
- cuda-libraries=11.7.1=0
|
15 |
+
- cuda-nvrtc=11.7.99=0
|
16 |
+
- cuda-nvtx=11.7.91=0
|
17 |
+
- cuda-runtime=11.7.1=0
|
18 |
+
- filelock=3.9.0=py310h06a4308_0
|
19 |
+
- gmp=6.2.1=h295c915_3
|
20 |
+
- gmpy2=2.1.2=py310heeb90bb_0
|
21 |
+
- intel-openmp=2023.1.0=hdb19cb5_46305
|
22 |
+
- jinja2=3.1.2=py310h06a4308_0
|
23 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
24 |
+
- libcublas=11.10.3.66=0
|
25 |
+
- libcufft=10.7.2.124=h4fbf590_0
|
26 |
+
- libcufile=1.6.1.9=0
|
27 |
+
- libcurand=10.3.2.106=0
|
28 |
+
- libcusolver=11.4.0.1=0
|
29 |
+
- libcusparse=11.7.4.91=0
|
30 |
+
- libffi=3.4.2=h6a678d5_6
|
31 |
+
- libgcc-ng=11.2.0=h1234567_1
|
32 |
+
- libgomp=11.2.0=h1234567_1
|
33 |
+
- libnpp=11.7.4.75=0
|
34 |
+
- libnvjpeg=11.8.0.2=0
|
35 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
36 |
+
- libuuid=1.41.5=h5eee18b_0
|
37 |
+
- markupsafe=2.1.1=py310h7f8727e_0
|
38 |
+
- mkl=2023.1.0=h6d00ec8_46342
|
39 |
+
- mpc=1.1.0=h10f8cd9_1
|
40 |
+
- mpfr=4.0.2=hb69a4c5_1
|
41 |
+
- ncurses=6.4=h6a678d5_0
|
42 |
+
- networkx=2.8.4=py310h06a4308_1
|
43 |
+
- openssl=1.1.1t=h7f8727e_0
|
44 |
+
- pip=23.0.1=py310h06a4308_0
|
45 |
+
- python=3.10.11=h7a1cb2a_2
|
46 |
+
- pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0
|
47 |
+
- pytorch-cuda=11.7=h778d358_3
|
48 |
+
- pytorch-mutex=1.0=cuda
|
49 |
+
- readline=8.2=h5eee18b_0
|
50 |
+
- setuptools=66.0.0=py310h06a4308_0
|
51 |
+
- sqlite=3.41.2=h5eee18b_0
|
52 |
+
- sympy=1.11.1=py310h06a4308_0
|
53 |
+
- tbb=2021.8.0=hdb19cb5_0
|
54 |
+
- tk=8.6.12=h1ccaba5_0
|
55 |
+
- torchtriton=2.0.0=py310
|
56 |
+
- typing_extensions=4.5.0=py310h06a4308_0
|
57 |
+
- tzdata=2023c=h04d1e81_0
|
58 |
+
- wheel=0.38.4=py310h06a4308_0
|
59 |
+
- xz=5.4.2=h5eee18b_0
|
60 |
+
- zlib=1.2.13=h5eee18b_0
|
61 |
+
- pip:
|
62 |
+
- anyio==3.6.2
|
63 |
+
- black==23.3.0
|
64 |
+
- certifi==2022.12.7
|
65 |
+
- charset-normalizer==3.1.0
|
66 |
+
- click==8.1.3
|
67 |
+
- fastapi==0.95.1
|
68 |
+
- fsspec==2023.4.0
|
69 |
+
- h11==0.14.0
|
70 |
+
- httptools==0.5.0
|
71 |
+
- huggingface-hub==0.14.1
|
72 |
+
- idna==3.4
|
73 |
+
- mpmath==1.2.1
|
74 |
+
- mypy==1.2.0
|
75 |
+
- mypy-extensions==1.0.0
|
76 |
+
- numpy==1.24.3
|
77 |
+
- packaging==23.1
|
78 |
+
- pathspec==0.11.1
|
79 |
+
- platformdirs==3.5.0
|
80 |
+
- pydantic==1.10.7
|
81 |
+
- python-dotenv==1.0.0
|
82 |
+
- pyyaml==6.0
|
83 |
+
- regex==2023.5.5
|
84 |
+
- requests==2.29.0
|
85 |
+
- ruff==0.0.264
|
86 |
+
- sniffio==1.3.0
|
87 |
+
- starlette==0.26.1
|
88 |
+
- tokenizers==0.13.3
|
89 |
+
- tomli==2.0.1
|
90 |
+
- tqdm==4.65.0
|
91 |
+
- transformers==4.28.1
|
92 |
+
- urllib3==1.26.15
|
93 |
+
- uvicorn==0.22.0
|
94 |
+
- uvloop==0.17.0
|
95 |
+
- watchfiles==0.19.0
|
96 |
+
- websockets==11.0.2
|
97 |
+
prefix: /home/vinh/anaconda3/envs/gistillery
|
pyproject.toml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.black]
|
2 |
+
line-length = 88
|
3 |
+
target_version = ['py310', 'py311']
|
4 |
+
preview = true
|
5 |
+
|
6 |
+
[tool.isort]
|
7 |
+
profile = "black"
|
8 |
+
|
9 |
+
[tool.mypy]
|
10 |
+
no_implicit_optional = true
|
11 |
+
strict = true
|
12 |
+
|
13 |
+
[[tool.mypy-transformers]]
|
14 |
+
ignore_missing_imports = true
|
requests.org
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#+title: Requests
|
2 |
+
|
3 |
+
#+begin_src bash
|
4 |
+
curl -X 'GET' \
|
5 |
+
'http://localhost:8080/clear/' \
|
6 |
+
-H 'accept: application/json'
|
7 |
+
#+end_src
|
8 |
+
|
9 |
+
#+RESULTS:
|
10 |
+
: OK
|
11 |
+
|
12 |
+
#+begin_src bash
|
13 |
+
# curl command to localhost and post the message "hi there"
|
14 |
+
curl -X 'POST' \
|
15 |
+
'http://localhost:8080/submit/' \
|
16 |
+
-H 'accept: application/json' \
|
17 |
+
-H 'Content-Type: application/json' \
|
18 |
+
-d '{
|
19 |
+
"author": "ben",
|
20 |
+
"content": "SAN FRANCISCO, May 2, 2023 PRNewswire -- GitLab Inc., the most comprehensive, scalable enterprise DevSecOps platform for software innovation, and Google Cloud today announced an extension of its strategic partnership to deliver secure AI offerings to the enterprise. GitLab is trusted by more than 50% of the Fortune 100 to secure and protect their most valuable assets, and leads with a privacy-first approach to AI. By leveraging Google Cloud'\''s customizable foundation models and open generative AI infrastructure, GitLab will provide customers with AI-assisted features directly within the enterprise DevSecOps platform."
|
21 |
+
}'
|
22 |
+
#+end_src
|
23 |
+
|
24 |
+
#+RESULTS:
|
25 |
+
: Submitted job 04d44970ced1473999dfab77b02202b8
|
26 |
+
|
27 |
+
#+begin_src bash
|
28 |
+
curl -X 'POST' \
|
29 |
+
'http://localhost:8080/submit/' \
|
30 |
+
-H 'accept: application/json' \
|
31 |
+
-H 'Content-Type: application/json' \
|
32 |
+
-d '{
|
33 |
+
"author": "ben",
|
34 |
+
"content": "In literature discussing why ChatGPT is able to capture so much of our imagination, I often come across two narratives: Scale: throwing more data and compute at it. UX: moving from a prompt interface to a more natural chat interface. A narrative that is often glossed over in the demo frenzy is the incredible technical creativity that went into making models like ChatGPT work. One such cool idea is RLHF (Reinforcement Learning from Human Feedback): incorporating reinforcement learning and human feedback into NLP. RL has been notoriously difficult to work with, and therefore, mostly confined to gaming and simulated environments like Atari or MuJoCo. Just five years ago, both RL and NLP were progressing pretty much orthogonally – different stacks, different techniques, and different experimentation setups. It’s impressive to see it work in a new domain at a massive scale. So, how exactly does RLHF work? Why does it work? This post will discuss the answers to those questions."
|
35 |
+
}'
|
36 |
+
#+end_src
|
37 |
+
|
38 |
+
#+RESULTS:
|
39 |
+
: Submitted job 3cc2104aec0748b1bd5743c321b169ac
|
40 |
+
|
41 |
+
#+begin_src bash
|
42 |
+
curl -X 'GET' \
|
43 |
+
'http://localhost:8080/check_status/22b158499b744f42918912cd387fd657' \
|
44 |
+
-H 'accept: application/json'
|
45 |
+
#+end_src
|
46 |
+
|
47 |
+
#+RESULTS:
|
48 |
+
| {"id":"22b158499b744f42918912cd387fd657" | status:"done" | last_updated:"2023-05-05T14:54:11"} |
|
49 |
+
|
50 |
+
#+begin_src bash
|
51 |
+
curl -X 'GET' \
|
52 |
+
'http://localhost:8080/recent/' \
|
53 |
+
-H 'accept: application/json'
|
54 |
+
#+end_src
|
55 |
+
|
56 |
+
#+RESULTS:
|
57 |
+
| [{"id":"3cc2104aec0748b1bd5743c321b169ac" | author:"ben" | summary:"A new approach to NLP that incorporates reinforcement learning and human feedback. How does it work? Why does it work? In this post | I’ll explain how it works. RLHF is a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback. It’s a new approach to NLP that incorporates reinforcement learning and human feedback." | tags:["#general" | #rlhf] | date:"2023-05-05T14:56:32"} | {"id":"04d44970ced1473999dfab77b02202b8" | author:"ben" | summary:"GitLab | the most comprehensive | scalable enterprise DevSecOps platform for software innovation | and Google Cloud today announced an extension of their strategic partnership to deliver secure AI offerings to the enterprise. By leveraging Google Cloud's customizable foundation models and open generative AI infrastructure | GitLab will provide customers with AI-assisted features directly within the enterprise DevSecOps platform. The company's AI capabilities are designed to help enterprises improve productivity and reduce costs." | tags:["#general"] | date:"2023-05-05T14:56:31"}] |
|
requirements-dev.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
black
|
2 |
+
isort
|
3 |
+
mypy
|
4 |
+
ruff
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
httpx
|
3 |
+
uvicorn[standard]
|
4 |
+
torch
|
5 |
+
transformers
|
src/base.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime as dt
|
2 |
+
import enum
|
3 |
+
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
|
7 |
+
class RequestInput(BaseModel):
    """Payload posted by a client to /submit/."""

    # display name of the submitter
    author: str
    # raw text to process; may be a single URL (see PlainUrlProcessor)
    content: str
|
10 |
+
|
11 |
+
|
12 |
+
class EntriesResult(BaseModel):
    """One processed entry as returned by the /recent/ endpoints."""

    # entry id: uuid4 hex assigned at submission time
    id: str
    author: str
    # model-generated summary of the content
    summary: str
    # hashtag strings, e.g. "#general"
    tags: list[str]
    # entry creation timestamp
    date: dt.datetime
|
18 |
+
|
19 |
+
|
20 |
+
class JobInput(BaseModel):
    """A pending job handed to the worker: entry id plus the submitted data."""

    id: str
    author: str
    content: str
|
24 |
+
|
25 |
+
|
26 |
+
class JobStatus(str, enum.Enum):
    """Lifecycle states of a job; str-valued so it serializes directly."""

    pending = "pending"
    done = "done"
    failed = "failed"
    cancelled = "cancelled"
    # reported when no job row exists for the requested id
    not_found = "not found"
|
32 |
+
|
33 |
+
|
34 |
+
class JobStatusResult(BaseModel):
    """Response of /check_status/{id}."""

    id: str
    status: JobStatus
    # None when the job was not found
    last_updated: dt.datetime | None
|
src/db.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sqlite3
|
3 |
+
from contextlib import contextmanager
|
4 |
+
from typing import Generator
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


# One row per submitted piece of content; id is a uuid4 hex assigned by the
# web service at submission time.
schema_entries = """
CREATE TABLE entries
(
    id TEXT PRIMARY KEY,
    author TEXT NOT NULL,
    source TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""

# Model-generated summaries; entry_id references entries.id (a uuid4 hex).
schema_summary = """
CREATE TABLE summaries
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    summary TEXT NOT NULL,
    summarizer_name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""

# Model-generated tags, one row per tag.
schema_tag = """
CREATE TABLE tags
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    tag TEXT NOT NULL,
    tagger_name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""

# Poor man's job queue: one row per submission; the worker polls for
# status = 'pending' rows and flips them to 'done'/'failed'.
schema_job = """
CREATE TABLE jobs
(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    entry_id TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'pending',
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY(entry_id) REFERENCES entries(id)
)
"""

# table name -> CREATE TABLE statement; used for lazy creation in
# _get_db_connection and for wiping everything in the /clear/ endpoint
TABLES = {
    "entries": schema_entries,
    "summaries": schema_summary,
    "tags": schema_tag,
    "jobs": schema_job,
}
# process-wide flag so the table-existence check runs only once per process
TABLES_CREATED = False
|
64 |
+
|
65 |
+
|
66 |
+
def _get_db_connection() -> sqlite3.Connection:
    """Open a connection to the sqlite DB, creating missing tables on first use."""
    global TABLES_CREATED

    # sqlite cannot deal with concurrent access, so we set a big timeout
    connection = sqlite3.connect("sqlite-data.db", timeout=30)
    if not TABLES_CREATED:
        cur = connection.cursor()
        # lazily create any table from TABLES that does not exist yet
        for name, ddl in TABLES.items():
            cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (name,),
            )
            if cur.fetchone() is None:
                logger.info(f"'{name}' table does not exist, creating it now...")
                cur.execute(ddl)
                connection.commit()
                logger.info("done")
        TABLES_CREATED = True
    return connection
|
91 |
+
|
92 |
+
|
93 |
+
@contextmanager
def get_db_cursor() -> Generator[sqlite3.Cursor, None, None]:
    """Yield a cursor on the shared sqlite DB.

    Commits when the managed block completes normally and rolls back when it
    raises, so a failing block does not persist partial writes (the original
    committed unconditionally in ``finally``). The cursor and connection are
    always closed.
    """
    conn = _get_db_connection()
    cursor = conn.cursor()
    try:
        yield cursor
    except BaseException:
        # don't persist partial writes from a failed block
        conn.rollback()
        raise
    else:
        conn.commit()
    finally:
        cursor.close()
        conn.close()
|
src/ml.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
|
5 |
+
import httpx
|
6 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
7 |
+
|
8 |
+
from base import JobInput
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Single seq2seq model shared by Summarizer and Tagger. Loaded at import
# time, so importing this module downloads/loads the model weights.
MODEL_NAME = "google/flan-t5-large"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
16 |
+
|
17 |
+
|
18 |
+
class Summarizer:
    """Produce a short summary of a text using the module-level model."""

    def __init__(self) -> None:
        self.template = "Summarize the text below in two sentences:\n\n{}"
        config = GenerationConfig.from_pretrained(MODEL_NAME)
        config.max_new_tokens = 200
        config.min_new_tokens = 100
        config.top_k = 5
        config.repetition_penalty = 1.5
        self.generation_config = config

    def __call__(self, x: str) -> str:
        """Return the generated summary for the given text."""
        prompt = self.template.format(x)
        encoded = tokenizer(prompt, return_tensors="pt")
        generated = model.generate(
            **encoded, generation_config=self.generation_config
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        assert isinstance(decoded, str)
        return decoded

    def get_name(self) -> str:
        """Identifier recorded alongside each generated summary."""
        return f"Summarizer({MODEL_NAME})"
|
37 |
+
|
38 |
+
|
39 |
+
class Tagger:
    """Generate hashtag-style tags for a text using the module-level model."""

    def __init__(self) -> None:
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        config = GenerationConfig.from_pretrained(MODEL_NAME)
        config.max_new_tokens = 50
        config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        config.temperature = 1.5
        self.generation_config = config

    def _extract_tags(self, text: str) -> list[str]:
        # keep whitespace-separated tokens that look like hashtags,
        # lower-cased, de-duplicated, and sorted for determinism
        found = {token.lower() for token in text.split() if token.startswith("#")}
        return sorted(found)

    def __call__(self, x: str) -> list[str]:
        """Return the sorted list of tags generated for the given text."""
        prompt = self.template.format(x)
        encoded = tokenizer(prompt, return_tensors="pt")
        generated = model.generate(
            **encoded, generation_config=self.generation_config
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        return self._extract_tags(decoded)

    def get_name(self) -> str:
        """Identifier recorded alongside each generated tag."""
        return f"Tagger({MODEL_NAME})"
|
68 |
+
|
69 |
+
|
70 |
+
class Processor(abc.ABC):
    """Base class for processors that turn a JobInput into plain text."""

    def __call__(self, job: JobInput) -> str:
        _id = job.id
        # bug fix: the old message interpolated the `input` *builtin* (it
        # logged "<built-in function input>") instead of the job being handled
        logger.info(
            f"Processing job with {self.__class__.__name__} (id={_id[:8]})"
        )
        result = self.process(job)
        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    def process(self, input: JobInput) -> str:
        """Turn the job's content into the text to summarize/tag."""
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Return True if this processor should handle the given job."""
        raise NotImplementedError

    def get_name(self) -> str:
        """Human-readable identifier for this processor."""
        raise NotImplementedError
|
86 |
+
|
87 |
+
|
88 |
+
class RawProcessor(Processor):
    """Fallback processor: passes the submitted content through untouched."""

    def process(self, input: JobInput) -> str:
        """Return the content exactly as submitted."""
        return input.content

    def match(self, input: JobInput) -> bool:
        # accepts anything; acts as the catch-all at the end of the registry
        return True

    def get_name(self) -> str:
        return self.__class__.__name__
|
97 |
+
|
98 |
+
|
99 |
+
class PlainUrlProcessor(Processor):
    # Handles submissions whose content contains exactly one URL: fetches the
    # page and prepends the URL to the fetched text.
    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        # set by match(); process() relies on match() having been called first
        self.url: str | None = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        # NOTE(review): storing the matched URL on self makes instances
        # stateful — not safe to share across concurrently processed jobs
        urls = list(self.regex.findall(input.content))
        if len(urls) == 1:
            self.url = urls[0]
            return True
        return False

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string"""
        assert isinstance(self.url, str)
        text = self.client.get(self.url).text
        assert isinstance(text, str)
        text = self.template.format(url=self.url, content=text)
        return text

    def get_name(self) -> str:
        return self.__class__.__name__
|
123 |
+
|
124 |
+
|
125 |
+
class ProcessorRegistry:
    """Ordered collection of processors; the first whose match() accepts wins."""

    def __init__(self) -> None:
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        # defaults are consulted only after explicitly registered processors
        self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])

    def register(self, processor: Processor) -> None:
        """Add a processor that takes precedence over the defaults."""
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        """Pick the first processor (explicit, then default) that matches."""
        candidates = self.registry + self.default_registry
        for candidate in candidates:
            if candidate.match(input):
                return candidate
        # should never be required (RawProcessor matches everything), but eh
        return RawProcessor()
|
src/webservice.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import uuid
|
3 |
+
|
4 |
+
from fastapi import FastAPI
|
5 |
+
|
6 |
+
from base import EntriesResult, JobStatus, JobStatusResult, RequestInput
|
7 |
+
from db import TABLES, get_db_cursor
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


# FastAPI application; served with e.g. `uvicorn webservice:app` (see README)
app = FastAPI()
|
15 |
+
|
16 |
+
|
17 |
+
@app.post("/submit/")
def submit_job(input: RequestInput) -> str:
    """Accept a submission, enqueue a pending job, and return the job id."""
    # poor man's job queue: a 'pending' row in jobs that the worker polls for
    job_id = uuid.uuid4().hex
    logger.info(f"Submitting job for (_id={job_id[:8]})")

    with get_db_cursor() as cursor:
        # create the job row
        cursor.execute(
            "INSERT INTO jobs (entry_id, status) VALUES (?, ?)",
            (job_id, "pending"),
        )
        # create the entry holding the submitted data
        cursor.execute(
            "INSERT INTO entries (id, author, source) VALUES (?, ?, ?)",
            (job_id, input.author, input.content),
        )

    return f"Submitted job {job_id}"
|
32 |
+
|
33 |
+
|
34 |
+
@app.get("/check_status/{_id}")
def check_status(_id: str) -> JobStatusResult:
    """Report the current job status for the given entry id."""
    with get_db_cursor() as cursor:
        cursor.execute(
            "SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
        )
        row = cursor.fetchone()

    if row is None:
        # no job row exists for this id
        return JobStatusResult(id=_id, status=JobStatus.not_found, last_updated=None)

    return JobStatusResult(id=_id, status=row[0], last_updated=row[1])
|
47 |
+
|
48 |
+
|
49 |
+
@app.get("/recent/")
def recent() -> list[EntriesResult]:
    """Return the 10 newest entries with their summary and tags."""
    with get_db_cursor() as cursor:
        # join each entry with its summary and collapse its tags into one
        # comma separated string per entry
        cursor.execute("""
            SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ","), e.created_at
            FROM entries e
            JOIN summaries s ON e.id = s.entry_id
            JOIN tags t ON e.id = t.entry_id
            GROUP BY e.id
            ORDER BY e.created_at DESC
            LIMIT 10
        """)
        rows = cursor.fetchall()

    return [
        EntriesResult(
            id=_id, author=author, summary=summary, tags=tags.split(","), date=date
        )
        for _id, author, summary, tags, date in rows
    ]
|
72 |
+
|
73 |
+
|
74 |
+
@app.get("/recent/{tag}")
def recent_tag(tag: str) -> list[EntriesResult]:
    """Return the 10 newest entries carrying the given tag."""
    # allow callers to omit the leading "#"
    if not tag.startswith("#"):
        tag = "#" + tag

    # same query as recent(), restricted to the requested tag
    with get_db_cursor() as cursor:
        cursor.execute(
            """
            SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ","), e.created_at
            FROM entries e
            JOIN summaries s ON e.id = s.entry_id
            JOIN tags t ON e.id = t.entry_id
            WHERE t.tag = ?
            GROUP BY e.id
            ORDER BY e.created_at DESC
            LIMIT 10
            """,
            (tag,),
        )
        rows = cursor.fetchall()

    return [
        EntriesResult(
            id=_id, author=author, summary=summary, tags=tags.split(","), date=date
        )
        for _id, author, summary, tags, date in rows
    ]
|
103 |
+
|
104 |
+
|
105 |
+
@app.get("/clear/")
def clear() -> str:
    """Delete every row from every table. Destructive — used for resets."""
    logger.warning("Clearing all tables")
    with get_db_cursor() as cursor:
        # table names come from the trusted TABLES constant, not user input
        for name in TABLES:
            cursor.execute(f"DELETE FROM {name}")
    return "OK"
|
src/worker.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
from base import JobInput
|
4 |
+
from db import get_db_cursor
|
5 |
+
from ml import ProcessorRegistry, Summarizer, Tagger
|
6 |
+
|
7 |
+
# seconds to wait between polls of the jobs table when nothing is pending
SLEEP_INTERVAL = 5


# set up once at startup; importing ml loads the model weights, so this is
# the slow part of worker start
processor_registry = ProcessorRegistry()
summarizer = Summarizer()
tagger = Tagger()
print("loaded ML models")
|
14 |
+
|
15 |
+
|
16 |
+
def check_pending_jobs() -> list[JobInput]:
    """Check DB for pending jobs"""
    # fetch pending jobs, joining author and content from the entries table
    query = """
    SELECT j.entry_id, e.author, e.source
    FROM jobs j
    JOIN entries e
    ON j.entry_id = e.id
    WHERE j.status = 'pending'
    """
    with get_db_cursor() as cursor:
        rows = cursor.execute(query).fetchall()
    return [JobInput(id=row[0], author=row[1], content=row[2]) for row in rows]
|
31 |
+
|
32 |
+
|
33 |
+
def store(
    job: JobInput,
    *,
    summary: str,
    tags: list[str],
    processor_name: str,
    summarizer_name: str,
    tagger_name: str,
) -> None:
    """Persist the summary and tags produced for a finished job.

    NOTE(review): processor_name is accepted for symmetry with the other
    *_name arguments but is currently not written to any table.
    """
    summary_row = (job.id, summary, summarizer_name)
    tag_rows = [(job.id, tag, tagger_name) for tag in tags]
    with get_db_cursor() as cursor:
        cursor.execute(
            (
                "INSERT INTO summaries (entry_id, summary, summarizer_name)"
                " VALUES (?, ?, ?)"
            ),
            summary_row,
        )
        cursor.executemany(
            "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
            tag_rows,
        )
|
55 |
+
|
56 |
+
|
57 |
+
def process_job(job: JobInput) -> None:
    """Run one job end to end: process, tag, summarize, store, update status.

    Any exception marks the job 'failed' instead of crashing the worker loop.
    """
    tic = time.perf_counter()
    print(f"Processing job for (id={job.id[:8]})")

    # care: acquire cursor (which leads to locking) as late as possible, since
    # the processing is slow and we don't want to block other workers during
    # that time
    try:
        processor = processor_registry.dispatch(job)
        processor_name = processor.get_name()
        processed = processor(job)

        tagger_name = tagger.get_name()
        tags = tagger(processed)

        summarizer_name = summarizer.get_name()
        summary = summarizer(processed)

        store(
            job,
            summary=summary,
            tags=tags,
            processor_name=processor_name,
            summarizer_name=summarizer_name,
            tagger_name=tagger_name,
        )
        # update job status to done
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'done' WHERE entry_id = ?", (job.id,)
            )
    except Exception as e:
        # update job status to failed so the job is not retried forever
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'failed' WHERE entry_id = ?", (job.id,)
            )
        print(f"Failed to process job for (id={job.id[:8]}): {e}")

    # NOTE(review): the "Finished" line is printed even when the job failed
    toc = time.perf_counter()
    print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
97 |
+
|
98 |
+
|
99 |
+
def main() -> None:
    """Poll the DB for pending jobs forever, processing them sequentially."""
    while True:
        jobs = check_pending_jobs()
        if jobs:
            print(f"Found {len(jobs)} pending job(s), processing...")
            for job in jobs:
                process_job(job)
        else:
            print("No pending jobs found, sleeping...")
            time.sleep(SLEEP_INTERVAL)
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == "__main__":
    # run the polling loop until interrupted with Ctrl-C
    try:
        main()
    except KeyboardInterrupt:
        print("Shutting down...")
        exit(0)
|