Benjamin Bossan committed on
Commit
6ac056a
1 Parent(s): ba8f25e

Store processed inputs in db

Browse files

Also, use named tuples for row factory.

src/gistillery/db.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  import sqlite3
 
4
  from contextlib import contextmanager
5
  from typing import Generator
6
 
@@ -57,20 +58,41 @@ CREATE TABLE jobs
57
  )
58
  """
59
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  TABLES = {
61
- "entries": schema_entries,
62
- "summaries": schema_summary,
63
- "tags": schema_tag,
64
- "jobs": schema_job,
 
65
  }
66
  TABLES_CREATED = False
67
 
68
 
 
 
 
 
 
 
 
69
  def _get_db_connection() -> sqlite3.Connection:
70
  global TABLES_CREATED
71
 
72
  # sqlite cannot deal with concurrent access, so we set a big timeout
73
  conn = sqlite3.connect(db_file, timeout=30)
 
74
  if TABLES_CREATED:
75
  return conn
76
 
 
1
  import logging
2
  import os
3
  import sqlite3
4
+ from collections import namedtuple
5
  from contextlib import contextmanager
6
  from typing import Generator
7
 
 
58
  )
59
  """
60
 
61
+ # store the processed inputs
62
+ schema_inputs = """
63
+ CREATE TABLE inputs
64
+ (
65
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
66
+ entry_id TEXT NOT NULL,
67
+ input TEXT NOT NULL,
68
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
69
+ FOREIGN KEY(entry_id) REFERENCES entries(id)
70
+ )
71
+ """
72
+
73
  TABLES = {
74
+ 'entries': schema_entries,
75
+ 'summaries': schema_summary,
76
+ 'tags': schema_tag,
77
+ 'jobs': schema_job,
78
+ 'inputs': schema_inputs,
79
  }
80
  TABLES_CREATED = False
81
 
82
 
83
# Row classes already generated, keyed by the tuple of result column names.
# Building a namedtuple class is comparatively expensive, so it is done once
# per distinct set of column names instead of once per fetched row.
_ROW_CLASSES = {}


# https://docs.python.org/3/library/sqlite3.html#how-to-create-and-use-row-factories
def namedtuple_factory(cursor, row):
    """Row factory that returns each result row as a named tuple.

    Args:
        cursor: the cursor that produced ``row``; only ``cursor.description``
            is read, to obtain the column names.
        row: the raw sequence of column values for one result row.

    Returns:
        A ``Row`` named tuple whose fields are the column names. Column names
        that are not valid Python identifiers (e.g. an unaliased ``COUNT(*)``)
        or that repeat an earlier name are replaced by positional names such
        as ``_0``, so such rows stay accessible by index instead of raising
        ``ValueError``.
    """
    fields = tuple(column[0] for column in cursor.description)
    cls = _ROW_CLASSES.get(fields)
    if cls is None:
        # rename=True swaps invalid or duplicate field names for "_<index>"
        cls = _ROW_CLASSES[fields] = namedtuple("Row", fields, rename=True)
    return cls._make(row)
88
+
89
+
90
  def _get_db_connection() -> sqlite3.Connection:
91
  global TABLES_CREATED
92
 
93
  # sqlite cannot deal with concurrent access, so we set a big timeout
94
  conn = sqlite3.connect(db_file, timeout=30)
95
+ conn.row_factory = namedtuple_factory
96
  if TABLES_CREATED:
97
  return conn
98
 
src/gistillery/webservice.py CHANGED
@@ -58,7 +58,7 @@ def recent() -> list[EntriesResult]:
58
  # get the last 10 entries, join summary and tag, where each tag is
59
  # joined to a comma separated str
60
  cursor.execute("""
61
- SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ","), e.created_at
62
  FROM entries e
63
  JOIN summaries s ON e.id = s.entry_id
64
  JOIN tags t ON e.id = t.entry_id
@@ -86,7 +86,7 @@ def recent_tag(tag: str) -> list[EntriesResult]:
86
  with get_db_cursor() as cursor:
87
  cursor.execute(
88
  """
89
- SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ","), e.created_at
90
  FROM entries e
91
  JOIN summaries s ON e.id = s.entry_id
92
  JOIN tags t ON e.id = t.entry_id
 
58
  # get the last 10 entries, join summary and tag, where each tag is
59
  # joined to a comma separated str
60
  cursor.execute("""
61
+ SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ",") tags, e.created_at
62
  FROM entries e
63
  JOIN summaries s ON e.id = s.entry_id
64
  JOIN tags t ON e.id = t.entry_id
 
86
  with get_db_cursor() as cursor:
87
  cursor.execute(
88
  """
89
+ SELECT e.id, e.author, s.summary, GROUP_CONCAT(t.tag, ",") tags, e.created_at
90
  FROM entries e
91
  JOIN summaries s ON e.id = s.entry_id
92
  JOIN tags t ON e.id = t.entry_id
src/gistillery/worker.py CHANGED
@@ -33,6 +33,7 @@ def check_pending_jobs() -> list[JobInput]:
33
 
34
  @dataclass
35
  class JobOutput:
 
36
  summary: str
37
  tags: list[str]
38
  processor_name: str
@@ -54,6 +55,7 @@ def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
54
  summary = summarizer(processed)
55
 
56
  return JobOutput(
 
57
  summary=summary,
58
  tags=tags,
59
  processor_name=processor_name,
@@ -64,7 +66,10 @@ def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
64
 
65
  def store(job: JobInput, output: JobOutput) -> None:
66
  with get_db_cursor() as cursor:
67
- # write to entries, summary, tags tables
 
 
 
68
  cursor.execute(
69
  (
70
  "INSERT INTO summaries (entry_id, summary, summarizer_name)"
 
33
 
34
  @dataclass
35
  class JobOutput:
36
+ processed: str
37
  summary: str
38
  tags: list[str]
39
  processor_name: str
 
55
  summary = summarizer(processed)
56
 
57
  return JobOutput(
58
+ processed=processed,
59
  summary=summary,
60
  tags=tags,
61
  processor_name=processor_name,
 
66
 
67
  def store(job: JobInput, output: JobOutput) -> None:
68
  with get_db_cursor() as cursor:
69
+ cursor.execute(
70
+ "INSERT INTO inputs (entry_id, input) VALUES (?, ?)",
71
+ (job.id, output.processed),
72
+ )
73
  cursor.execute(
74
  (
75
  "INSERT INTO summaries (entry_id, summary, summarizer_name)"
tests/{test_webservice.py → test_app.py} RENAMED
@@ -19,6 +19,13 @@ class TestWebservice:
19
  filename = tmp_path / "test-db.sqlite"
20
  os.environ["DB_FILE_NAME"] = str(filename)
21
 
 
 
 
 
 
 
 
22
  @pytest.fixture
23
  def client(self):
24
  from gistillery.webservice import app
@@ -117,6 +124,31 @@ class TestWebservice:
117
  }
118
  assert last_updated is None
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
121
  # monkeypatch uuid4 to return a known value
122
  job_id = "abc1234"
@@ -182,26 +214,17 @@ class TestWebservice:
182
  assert resp0["summary"] == "this would"
183
  assert resp0["tags"] == sorted(["#this", "#would", "#be"])
184
 
185
- def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
186
- # monkeypatch uuid4 to return a known value
187
- job_id = "abc1234"
188
- monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
189
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
190
- # patch gistillery.worker._process_job to raise an exception
 
191
 
192
- def raise_(ex):
193
- raise ex
194
 
195
- monkeypatch.setattr(
196
- "gistillery.worker._process_job",
197
- lambda job, registry: raise_(RuntimeError("something went wrong")),
198
- )
199
  self.process_jobs(mlregistry)
200
-
201
- resp = client.get(f"/check_job_status/{job_id}")
202
- output = resp.json()
203
- output.pop("last_updated")
204
- assert output == {
205
- "id": job_id,
206
- "status": "failed",
207
- }
 
19
  filename = tmp_path / "test-db.sqlite"
20
  os.environ["DB_FILE_NAME"] = str(filename)
21
 
22
@pytest.fixture
def cursor(self):
    """Yield a cursor obtained from ``gistillery.db.get_db_cursor``.

    The context manager stays open for as long as the test uses the fixture
    and is exited afterwards.
    """
    # imported inside the fixture rather than at module level, as in the
    # sibling ``client`` fixture
    from gistillery.db import get_db_cursor

    with get_db_cursor() as db_cursor:
        yield db_cursor
28
+
29
  @pytest.fixture
30
  def client(self):
31
  from gistillery.webservice import app
 
124
  }
125
  assert last_updated is None
126
 
127
def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
    """A job whose processing raises is reported with status ``failed``."""
    # pin uuid4 so the id used to look up the job is known in advance
    job_id = "abc1234"
    monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
    client.post("/submit", json={"author": "ben", "content": "this is a test"})

    def failing_process_job(job, registry):
        raise RuntimeError("something went wrong")

    # make the job processing fail
    monkeypatch.setattr("gistillery.worker._process_job", failing_process_job)
    self.process_jobs(mlregistry)

    status = client.get(f"/check_job_status/{job_id}").json()
    # last_updated is removed before comparing against the expected payload
    del status["last_updated"]
    assert status == {"id": job_id, "status": "failed"}
151
+
152
  def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
153
  # monkeypatch uuid4 to return a known value
154
  job_id = "abc1234"
 
214
  assert resp0["summary"] == "this would"
215
  assert resp0["tags"] == sorted(["#this", "#would", "#be"])
216
 
217
def test_clear(self, client, cursor, mlregistry):
    """After a GET on ``/clear`` the entries table has no rows left."""
    count_entries = "SELECT COUNT(*) c FROM entries"

    client.post("/submit", json={"author": "ben", "content": "this is a test"})
    self.process_jobs(mlregistry)
    before = cursor.execute(count_entries).fetchone()
    assert before[0] == 1

    client.get("/clear")
    after = cursor.execute(count_entries).fetchone()
    assert after[0] == 0
224
 
225
def test_inputs_stored(self, client, cursor, mlregistry):
    """Processing a submitted job leaves exactly one row in ``inputs``.

    The submitted content has surrounding whitespace; the stored ``input``
    is expected without it.
    """
    payload = {"author": "ben", "content": " this is a test\n"}
    client.post("/submit", json=payload)
    self.process_jobs(mlregistry)

    stored = cursor.execute("SELECT * FROM inputs").fetchall()
    assert len(stored) == 1
    assert stored[0].input == "this is a test"