# gistillery/tests/test_app.py
# Benjamin Bossan: "Apply mypy, black, ruff" (c19ef6e)
# 8.07 kB
import datetime as dt
import os
from types import SimpleNamespace
import pytest
from fastapi.testclient import TestClient
def is_roughly_now(datetime_str):
    """Check if a datetime string is roughly from now.

    Args:
        datetime_str: Timestamp in a format accepted by
            ``datetime.fromisoformat``. A value without timezone information
            is interpreted as UTC.

    Returns:
        True if the timestamp is less than 3 seconds away from the current
        UTC time, in either direction; False otherwise.
    """
    now = dt.datetime.now(dt.timezone.utc)
    parsed = dt.datetime.fromisoformat(datetime_str)
    if parsed.tzinfo is None:
        # Naive timestamps are taken to be UTC; attaching the zone keeps the
        # subtraction below valid for both naive and aware inputs.
        parsed = parsed.replace(tzinfo=dt.timezone.utc)
    # abs() so that a timestamp far in the future is not accepted as "now".
    return abs((now - parsed).total_seconds()) < 3
class TestWebservice:
    """Tests for the gistillery webservice endpoints and job processing.

    Every test gets its own ``DB_FILE_NAME`` (a path under ``tmp_path``) and
    the ML models are replaced by trivial dummy implementations, so the tests
    do not depend on real models.
    """

    @pytest.fixture(autouse=True)
    def db_file(self, tmp_path, monkeypatch):
        """Set ``DB_FILE_NAME`` to a per-test SQLite path.

        ``monkeypatch.setenv`` is used instead of assigning to ``os.environ``
        so that the variable is restored when the test finishes and does not
        leak into other tests.
        """
        filename = tmp_path / "test-db.sqlite"
        monkeypatch.setenv("DB_FILE_NAME", str(filename))

    @pytest.fixture
    def cursor(self):
        """Yield a cursor obtained from ``gistillery.db.get_db_cursor``."""
        # Imported inside the fixture, i.e. after the autouse ``db_file``
        # fixture has set DB_FILE_NAME.
        from gistillery.db import get_db_cursor

        with get_db_cursor() as cursor:
            yield cursor

    @pytest.fixture
    def client(self):
        """Return a ``TestClient`` for the app after requesting ``/clear``."""
        from gistillery.webservice import app

        client = TestClient(app)
        client.get("/clear")
        return client

    @pytest.fixture
    def mlregistry(self):
        """Return an ``MlRegistry`` populated with dummy models.

        The registry holds a ``RawTextProcessor``, a summarizer that truncates
        its input and a tagger that derives tags from the first words.
        """
        # use dummy models
        from gistillery.ml import Summarizer, Tagger
        from gistillery.preprocessing import RawTextProcessor
        from gistillery.registry import MlRegistry

        class DummySummarizer(Summarizer):
            """Returns the first 10 characters of the input"""

            def __init__(self, *args, **kwargs):
                # Deliberately skips the parent initializer.
                pass

            def get_name(self):
                return "dummy summarizer"

            def __call__(self, x):
                return x[:10]

        class DummyTagger(Tagger):
            """Returns the first 3 words of the input, each prefixed with '#'"""

            def __init__(self, *args, **kwargs):
                # Deliberately skips the parent initializer.
                pass

            def get_name(self):
                return "dummy tagger"

            def __call__(self, x):
                return ["#" + word for word in x.split(maxsplit=4)[:3]]

        registry = MlRegistry()
        registry.register_processor(RawTextProcessor())

        # arguments don't matter for dummy summarizer and tagger
        summarizer = DummySummarizer(None, None, None, None)
        registry.register_summarizer(summarizer)

        tagger = DummyTagger(None, None, None, None)
        registry.register_tagger(tagger)
        return registry

    def process_jobs(self, registry):
        """Process all pending jobs with the given registry."""
        # emulate work of the background worker
        from gistillery.worker import check_pending_jobs, process_job

        jobs = check_pending_jobs()
        for job in jobs:
            process_job(job, registry)

    def _submit_with_known_job_id(self, client, monkeypatch, job_id="abc1234"):
        """Submit one entry while ``uuid.uuid4`` is patched to yield ``job_id``.

        Returns the job id so the caller can query the job status afterwards.
        """
        # monkeypatch uuid4 to return a known value
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        return job_id

    def _submit_two_entries(self, client):
        """Submit the two fixed entries used by the ``/recent`` tests."""
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )

    def test_status(self, client):
        resp = client.get("/status")
        assert resp.status_code == 200
        assert resp.json() == "OK"

    def test_recent_empty(self, client):
        resp = client.get("/recent")
        assert resp.json() == []

    def test_recent_tag_empty(self, client):
        resp = client.get("/recent/general")
        assert resp.json() == []

    def test_submitted_job_status_pending(self, client, monkeypatch):
        job_id = self._submit_with_known_job_id(client, monkeypatch)

        # no jobs are processed here, so the job is expected to be pending
        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "pending",
        }
        assert is_roughly_now(last_updated)

    def test_submitted_job_status_not_found(self, client, monkeypatch):
        self._submit_with_known_job_id(client, monkeypatch)

        # query an id that differs from the one that was submitted
        other_job_id = "def5678"
        resp = client.get(f"/check_job_status/{other_job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": other_job_id,
            "status": "not found",
        }
        assert last_updated is None

    def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
        job_id = self._submit_with_known_job_id(client, monkeypatch)

        # patch gistillery.worker._process_job to raise an exception
        def raise_(ex):
            raise ex

        # make the job processing fail
        monkeypatch.setattr(
            "gistillery.worker._process_job",
            lambda job, registry: raise_(RuntimeError("something went wrong")),
        )

        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        # the timestamp is not checked in this test
        output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "failed",
        }

    def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
        job_id = self._submit_with_known_job_id(client, monkeypatch)
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "done",
        }
        assert is_roughly_now(last_updated)

    def test_recent_with_entries(self, client, mlregistry):
        # submit 2 entries
        self._submit_two_entries(client)
        self.process_jobs(mlregistry)
        resp = client.get("/recent").json()

        # results are sorted by recency but since dummy models are so fast, the
        # date in the db could be the same, so we sort by author
        resp = sorted(resp, key=lambda x: x["author"])
        assert len(resp) == 2

        resp0 = resp[0]
        assert resp0["author"] == "maxi"
        assert resp0["summary"] == "this is a "
        assert resp0["tags"] == sorted(["#this", "#is", "#a"])

        resp1 = resp[1]
        assert resp1["author"] == "mini"
        assert resp1["summary"] == "this would"
        assert resp1["tags"] == sorted(["#this", "#would", "#be"])

    def test_recent_tag_with_entries(self, client, mlregistry):
        # submit 2 entries
        self._submit_two_entries(client)
        self.process_jobs(mlregistry)

        # the "this" tag is in both entries
        resp = client.get("/recent/this").json()
        assert len(resp) == 2

        # the "would" tag is in only one entry
        resp = client.get("/recent/would").json()
        assert len(resp) == 1

        resp0 = resp[0]
        assert resp0["author"] == "mini"
        assert resp0["summary"] == "this would"
        assert resp0["tags"] == sorted(["#this", "#would", "#be"])

    def test_clear(self, client, cursor, mlregistry):
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(mlregistry)
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1

        client.get("/clear")
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0

    def test_inputs_stored(self, client, cursor, mlregistry):
        # content is submitted with surrounding whitespace; the stored input
        # is expected without it
        client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
        self.process_jobs(mlregistry)
        rows = cursor.execute("SELECT * FROM inputs").fetchall()
        assert len(rows) == 1
        assert rows[0].input == "this is a test"