# Spaces:
# Runtime error
# Runtime error
import datetime as dt
import os
from types import SimpleNamespace

import pytest
from fastapi.testclient import TestClient
def is_roughly_now(datetime_str, tolerance=3):
    """Check whether an ISO-8601 datetime string lies within ``tolerance`` seconds of now.

    Naive timestamps are treated as UTC. Timezone-aware timestamps are converted
    to UTC before comparing. Timestamps in the past and in the future are both
    accepted only if they fall inside the tolerance window.

    Args:
        datetime_str: A string accepted by ``datetime.fromisoformat``.
        tolerance: Maximum allowed distance from the current time, in seconds.

    Returns:
        True if ``|now - timestamp| < tolerance``, otherwise False.
    """
    parsed = dt.datetime.fromisoformat(datetime_str)
    if parsed.tzinfo is not None:
        # normalise aware values so they can be subtracted from a naive UTC "now"
        parsed = parsed.astimezone(dt.timezone.utc).replace(tzinfo=None)
    # naive UTC "now" without the deprecated datetime.utcnow()
    now = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)
    # abs(): a timestamp far in the future is not "roughly now" either
    return abs((now - parsed).total_seconds()) < tolerance
class TestWebservice:
    """End-to-end tests of the gistillery webservice.

    Every test runs against a fresh SQLite file. The tests that need job
    processing use dummy summarizer/tagger models via the ``mlregistry`` fixture.
    """

    @pytest.fixture(autouse=True)
    def db_file(self, tmp_path, monkeypatch):
        """Point ``DB_FILE_NAME`` at a fresh SQLite file for every test.

        No test requests this fixture by name, so it must be ``autouse`` to take
        effect. ``monkeypatch.setenv`` restores the previous environment when
        the test ends. This avoids leaking a path into a deleted temporary
        directory to later tests.
        """
        filename = tmp_path / "test-db.sqlite"
        # NOTE(review): presumably read by gistillery.db when opening the DB — confirm
        monkeypatch.setenv("DB_FILE_NAME", str(filename))

    @pytest.fixture
    def cursor(self):
        """Yield a cursor from ``gistillery.db.get_db_cursor``.

        The ``get_db_cursor`` context manager is exited on fixture teardown.
        """
        # imported inside the fixture, presumably so gistillery sees the
        # DB_FILE_NAME set by ``db_file`` — confirm
        from gistillery.db import get_db_cursor

        with get_db_cursor() as cursor:
            yield cursor

    @pytest.fixture
    def client(self):
        """Return a ``TestClient`` for the app, starting from an empty database."""
        from gistillery.webservice import app

        client = TestClient(app)
        # /clear empties the database (see test_clear), giving each test a clean slate
        client.get("/clear")
        return client

    @pytest.fixture
    def mlregistry(self):
        """Return an ``MlRegistry`` wired with cheap, deterministic dummy models."""
        from gistillery.ml import Summarizer, Tagger
        from gistillery.preprocessing import RawTextProcessor
        from gistillery.registry import MlRegistry

        class DummySummarizer(Summarizer):
            """Returns the first 10 characters of the input"""

            def __init__(self, *args, **kwargs):
                # deliberately skips the base-class __init__ (presumably to
                # avoid loading a real model)
                pass

            def get_name(self):
                return "dummy summarizer"

            def __call__(self, x):
                return x[:10]

        class DummyTagger(Tagger):
            """Returns the first 3 words of the input, each prefixed with '#'"""

            def __init__(self, *args, **kwargs):
                # deliberately skips the base-class __init__, like DummySummarizer
                pass

            def get_name(self):
                return "dummy tagger"

            def __call__(self, x):
                return ["#" + word for word in x.split(maxsplit=4)[:3]]

        registry = MlRegistry()
        registry.register_processor(RawTextProcessor())
        # arguments don't matter for dummy summarizer and tagger
        summarizer = DummySummarizer(None, None, None, None)
        registry.register_summarizer(summarizer)
        tagger = DummyTagger(None, None, None, None)
        registry.register_tagger(tagger)
        return registry

    def process_jobs(self, registry):
        """Synchronously process every pending job, emulating the background worker."""
        from gistillery.worker import check_pending_jobs, process_job

        for job in check_pending_jobs():
            process_job(job, registry)

    def test_status(self, client):
        resp = client.get("/status")
        assert resp.status_code == 200
        assert resp.json() == "OK"

    def test_recent_empty(self, client):
        resp = client.get("/recent")
        assert resp.json() == []

    def test_recent_tag_empty(self, client):
        resp = client.get("/recent/general")
        assert resp.json() == []

    def test_submitted_job_status_pending(self, client, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "pending",
        }
        assert is_roughly_now(last_updated)

    def test_submitted_job_status_not_found(self, client, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        other_job_id = "def5678"
        resp = client.get(f"/check_job_status/{other_job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": other_job_id,
            "status": "not found",
        }
        assert last_updated is None

    def test_submitted_job_failed(self, client, mlregistry, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})

        def fail(job, registry):
            raise RuntimeError("something went wrong")

        # make the job processing fail by replacing gistillery.worker._process_job
        monkeypatch.setattr("gistillery.worker._process_job", fail)
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "failed",
        }

    def test_submitted_job_status_done(self, client, mlregistry, monkeypatch):
        # monkeypatch uuid4 to return a known value
        job_id = "abc1234"
        monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(mlregistry)

        resp = client.get(f"/check_job_status/{job_id}")
        output = resp.json()
        last_updated = output.pop("last_updated")
        assert output == {
            "id": job_id,
            "status": "done",
        }
        assert is_roughly_now(last_updated)

    def test_recent_with_entries(self, client, mlregistry):
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )
        self.process_jobs(mlregistry)
        resp = client.get("/recent").json()

        # results are sorted by recency but since dummy models are so fast, the
        # date in the db could be the same, so we sort by author
        resp = sorted(resp, key=lambda x: x["author"])
        assert len(resp) == 2

        resp0 = resp[0]
        assert resp0["author"] == "maxi"
        assert resp0["summary"] == "this is a "
        assert resp0["tags"] == sorted(["#this", "#is", "#a"])

        resp1 = resp[1]
        assert resp1["author"] == "mini"
        assert resp1["summary"] == "this would"
        assert resp1["tags"] == sorted(["#this", "#would", "#be"])

    def test_recent_tag_with_entries(self, client, mlregistry):
        # submit 2 entries
        client.post(
            "/submit", json={"author": "maxi", "content": "this is a first test"}
        )
        client.post(
            "/submit",
            json={"author": "mini", "content": "this would be something else"},
        )
        self.process_jobs(mlregistry)

        # the "this" tag is in both entries
        resp = client.get("/recent/this").json()
        assert len(resp) == 2

        # the "would" tag is in only one entry
        resp = client.get("/recent/would").json()
        assert len(resp) == 1

        resp0 = resp[0]
        assert resp0["author"] == "mini"
        assert resp0["summary"] == "this would"
        assert resp0["tags"] == sorted(["#this", "#would", "#be"])

    def test_clear(self, client, cursor, mlregistry):
        client.post("/submit", json={"author": "ben", "content": "this is a test"})
        self.process_jobs(mlregistry)
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1

        client.get("/clear")
        assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0

    def test_inputs_stored(self, client, cursor, mlregistry):
        client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
        self.process_jobs(mlregistry)
        rows = cursor.execute("SELECT * FROM inputs").fetchall()
        assert len(rows) == 1
        assert rows[0].input == "this is a test"