Push
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change.
- Dockerfile +4 -1
- requirements.txt +40 -39
- src/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc +0 -0
- src/__pycache__/data_loader.cpython-39.pyc +0 -0
- src/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/__pycache__/router_concept.cpython-39.pyc +0 -0
- src/__pycache__/router_dataset.cpython-39.pyc +0 -0
- src/__pycache__/router_signal.cpython-39.pyc +0 -0
- src/__pycache__/schema.cpython-39.pyc +0 -0
- src/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/__pycache__/server.cpython-39.pyc +0 -0
- src/__pycache__/server_concept_test.cpython-39-pytest-7.3.1.pyc +0 -0
- src/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/__pycache__/server_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/__pycache__/tasks.cpython-39.pyc +0 -0
- src/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc +0 -0
- src/concepts/__pycache__/concept.cpython-39.pyc +0 -0
- src/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/concepts/__pycache__/db_concept.cpython-39.pyc +0 -0
- src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.3.1.pyc +0 -0
- src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/concepts/concept.py +167 -50
- src/concepts/db_concept.py +80 -12
- src/concepts/db_concept_test.py +33 -27
- src/data/__pycache__/dataset.cpython-39.pyc +0 -0
- src/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_duckdb.cpython-39.pyc +0 -0
- src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc +0 -0
- src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc +0 -0
- src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/dataset_utils.cpython-39.pyc +0 -0
- src/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc +0 -0
- src/data/__pycache__/duckdb_utils.cpython-39.pyc +0 -0
- src/data/dataset.py +1 -1
- src/data/dataset_duckdb.py +59 -33
- src/data/dataset_select_groups_test.py +9 -5
- src/data/dataset_select_rows_filter_test.py +88 -0
- src/data/dataset_select_rows_search_test.py +1 -9
- src/data/dataset_stats_test.py +5 -2
- src/data/dataset_utils.py +5 -1
Dockerfile
CHANGED
@@ -13,7 +13,7 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
 # Copy the data to /data, the HF persistent storage. We do this after pip install to avoid
-# re-installing dependencies if the data changes.
+# re-installing dependencies if the data changes, which is likely more often.
 WORKDIR /
 COPY /data /data
 WORKDIR /server
@@ -27,4 +27,7 @@ COPY /web/blueprint/build ./web/blueprint/build
 # Copy python files.
 COPY /src ./src/
 
+# Copy the entrypoint file.
+COPY docker_entrypoint.sh .
+
 CMD ["uvicorn", "src.server:app", "--host", "0.0.0.0", "--port", "5432"]
requirements.txt
CHANGED
(Removed lines whose old version numbers were cut off in this view are shown truncated, exactly as captured.)
@@ -1,6 +1,6 @@
 aiohttp==3.8.4 ; python_version >= "3.9" and python_version < "3.10"
 aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.10"
-anyio==3.7.
+anyio==3.7.1 ; python_version >= "3.9" and python_version < "3.10"
 async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "3.10"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.10"
 blis==0.7.9 ; python_version >= "3.9" and python_version < "3.10"
@@ -12,43 +12,43 @@ click==8.1.3 ; python_version >= "3.9" and python_version < "3.10"
 cloudpickle==2.2.1 ; python_version >= "3.9" and python_version < "3.10"
 cohere==3.10.0 ; python_version >= "3.9" and python_version < "3.10"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.10" and (platform_system == "Windows" or sys_platform == "win32")
-confection==0.0
+confection==0.1.0 ; python_version >= "3.9" and python_version < "3.10"
 cymem==2.0.7 ; python_version >= "3.9" and python_version < "3.10"
 cytoolz==0.12.1 ; python_version >= "3.9" and python_version < "3.10"
-dask==2023.
-datasets==2.
+dask==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
+datasets==2.13.1 ; python_version >= "3.9" and python_version < "3.10"
 decorator==5.1.1 ; python_version >= "3.9" and python_version < "3.10"
 dill==0.3.6 ; python_version >= "3.9" and python_version < "3.10"
-distributed==2023.
-duckdb==0.8.
+distributed==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
+duckdb==0.8.1 ; python_version >= "3.9" and python_version < "3.10"
 email-reply-parser==0.5.12 ; python_version >= "3.9" and python_version < "3.10"
-exceptiongroup==1.1.
-fastapi==0.
-filelock==3.12.
+exceptiongroup==1.1.2 ; python_version >= "3.9" and python_version < "3.10"
+fastapi==0.98.0 ; python_version >= "3.9" and python_version < "3.10"
+filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.10"
 floret==0.10.3 ; python_version >= "3.9" and python_version < "3.10"
 frozenlist==1.3.3 ; python_version >= "3.9" and python_version < "3.10"
-fsspec==2023.
-fsspec[http]==2023.
-gcsfs==2023.
-google-api-core==2.11.
-google-api-python-client==2.
+fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.10"
+fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "3.10"
+gcsfs==2023.6.0 ; python_version >= "3.9" and python_version < "3.10"
+google-api-core==2.11.1 ; python_version >= "3.9" and python_version < "3.10"
+google-api-python-client==2.92.0 ; python_version >= "3.9" and python_version < "3.10"
 google-auth-httplib2==0.1.0 ; python_version >= "3.9" and python_version < "3.10"
 google-auth-oauthlib==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
-google-auth==2.
-google-cloud-core==2.3.
-google-cloud-storage==2.
+google-auth==2.21.0 ; python_version >= "3.9" and python_version < "3.10"
+google-cloud-core==2.3.3 ; python_version >= "3.9" and python_version < "3.10"
+google-cloud-storage==2.10.0 ; python_version >= "3.9" and python_version < "3.10"
 google-crc32c==1.5.0 ; python_version >= "3.9" and python_version < "3.10"
 google-resumable-media==2.5.0 ; python_version >= "3.9" and python_version < "3.10"
-googleapis-common-protos==1.59.
+googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "3.10"
 h11==0.14.0 ; python_version >= "3.9" and python_version < "3.10"
 httplib2==0.22.0 ; python_version >= "3.9" and python_version < "3.10"
 httptools==0.5.0 ; python_version >= "3.9" and python_version < "3.10"
 huggingface-hub==0.15.1 ; python_version >= "3.9" and python_version < "3.10"
 idna==3.4 ; python_version >= "3.9" and python_version < "3.10"
-importlib-metadata==6.
-jellyfish==0.
+importlib-metadata==6.7.0 ; python_version >= "3.9" and python_version < "3.10"
+jellyfish==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
 jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.10"
-joblib==1.
+joblib==1.3.1 ; python_version >= "3.9" and python_version < "3.10"
 langcodes==3.3.0 ; python_version >= "3.9" and python_version < "3.10"
 locket==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
 markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.10"
@@ -59,24 +59,24 @@ multiprocess==0.70.14 ; python_version >= "3.9" and python_version < "3.10"
 murmurhash==1.0.9 ; python_version >= "3.9" and python_version < "3.10"
 networkx==3.1 ; python_version >= "3.9" and python_version < "3.10"
 nltk==3.8.1 ; python_version >= "3.9" and python_version < "3.10"
-numpy==1.
+numpy==1.25.0 ; python_version >= "3.9" and python_version < "3.10"
 oauthlib==3.2.2 ; python_version >= "3.9" and python_version < "3.10"
 openai-function-call==0.0.5 ; python_version >= "3.9" and python_version < "3.10"
 openai==0.27.8 ; python_version >= "3.9" and python_version < "3.10"
-orjson==3.9.
+orjson==3.9.1 ; python_version >= "3.9" and python_version < "3.10"
 packaging==23.1 ; python_version >= "3.9" and python_version < "3.10"
-pandas==2.0.
+pandas==2.0.3 ; python_version >= "3.9" and python_version < "3.10"
 partd==1.4.0 ; python_version >= "3.9" and python_version < "3.10"
-pathy==0.10.
+pathy==0.10.2 ; python_version >= "3.9" and python_version < "3.10"
 pillow==9.5.0 ; python_version >= "3.9" and python_version < "3.10"
 preshed==3.0.8 ; python_version >= "3.9" and python_version < "3.10"
-protobuf==4.23.
+protobuf==4.23.3 ; python_version >= "3.9" and python_version < "3.10"
 psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.10"
 pyarrow==9.0.0 ; python_version >= "3.9" and python_version < "3.10"
 pyasn1-modules==0.3.0 ; python_version >= "3.9" and python_version < "3.10"
 pyasn1==0.5.0 ; python_version >= "3.9" and python_version < "3.10"
-pydantic==1.10.
-pyparsing==3.0
+pydantic==1.10.11 ; python_version >= "3.9" and python_version < "3.10"
+pyparsing==3.1.0 ; python_version >= "3.9" and python_version < "3.10"
 pyphen==0.14.0 ; python_version >= "3.9" and python_version < "3.10"
 python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "3.10"
 python-dotenv==1.0.0 ; python_version >= "3.9" and python_version < "3.10"
@@ -85,10 +85,10 @@ pyyaml==6.0 ; python_version >= "3.9" and python_version < "3.10"
 regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.10"
 requests-oauthlib==1.3.1 ; python_version >= "3.9" and python_version < "3.10"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.10"
-responses==0.18.0 ; python_version >= "3.9" and python_version < "3.10"
 rsa==4.9 ; python_version >= "3.9" and python_version < "3.10"
-
-
+safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.10"
+scikit-learn==1.3.0 ; python_version >= "3.9" and python_version < "3.10"
+scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.10"
 sentence-transformers==2.2.2 ; python_version >= "3.9" and python_version < "3.10"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.10"
 setuptools==65.7.0 ; python_version >= "3.9" and python_version < "3.10"
@@ -98,11 +98,12 @@ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "3.10"
 sortedcontainers==2.4.0 ; python_version >= "3.9" and python_version < "3.10"
 spacy-legacy==3.0.12 ; python_version >= "3.9" and python_version < "3.10"
 spacy-loggers==1.0.4 ; python_version >= "3.9" and python_version < "3.10"
-spacy==3.5.
+spacy==3.5.4 ; python_version >= "3.9" and python_version < "3.10"
 srsly==2.4.6 ; python_version >= "3.9" and python_version < "3.10"
 starlette==0.27.0 ; python_version >= "3.9" and python_version < "3.10"
 sympy==1.12 ; python_version >= "3.9" and python_version < "3.10"
-tblib==
+tblib==2.0.0 ; python_version >= "3.9" and python_version < "3.10"
+tenacity==8.2.2 ; python_version >= "3.9" and python_version < "3.10"
 textacy==0.13.0 ; python_version >= "3.9" and python_version < "3.10"
 thinc==8.1.10 ; python_version >= "3.9" and python_version < "3.10"
 threadpoolctl==3.1.0 ; python_version >= "3.9" and python_version < "3.10"
@@ -112,16 +113,16 @@ torch==2.0.1 ; python_version >= "3.9" and python_version < "3.10"
 torchvision==0.15.2 ; python_version >= "3.9" and python_version < "3.10"
 tornado==6.3.2 ; python_version >= "3.9" and python_version < "3.10"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.10"
-transformers==4.
-typer==0.
-types-psutil==5.9.5.
-typing-extensions==4.
+transformers==4.30.2 ; python_version >= "3.9" and python_version < "3.10"
+typer==0.9.0 ; python_version >= "3.9" and python_version < "3.10"
+types-psutil==5.9.5.15 ; python_version >= "3.9" and python_version < "3.10"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.10"
 tzdata==2023.3 ; python_version >= "3.9" and python_version < "3.10"
 uritemplate==4.1.1 ; python_version >= "3.9" and python_version < "3.10"
 urllib3==1.26.16 ; python_version >= "3.9" and python_version < "3.10"
-uvicorn[standard]==0.
+uvicorn[standard]==0.22.0 ; python_version >= "3.9" and python_version < "3.10"
 uvloop==0.17.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "3.10"
-wasabi==1.1.
+wasabi==1.1.2 ; python_version >= "3.9" and python_version < "3.10"
 watchfiles==0.19.0 ; python_version >= "3.9" and python_version < "3.10"
 websockets==11.0.3 ; python_version >= "3.9" and python_version < "3.10"
 xxhash==3.2.0 ; python_version >= "3.9" and python_version < "3.10"
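Every pin in this file carries a PEP 508 environment marker (the `; python_version >= "3.9" and python_version < "3.10"` suffix), so pip only installs the requirement on a matching interpreter. A minimal sketch of how such a marker evaluates, using the packaging library that is itself pinned above (packaging==23.1); the marker string is copied from this file, the rest is illustrative:

from packaging.markers import Marker

# The environment marker used throughout this requirements file.
marker = Marker('python_version >= "3.9" and python_version < "3.10"')

# Evaluates against the running interpreter by default...
print(marker.evaluate())
# ...or against an explicit environment dict (versions compare numerically).
print(marker.evaluate({'python_version': '3.9'}))   # True
print(marker.evaluate({'python_version': '3.11'}))  # False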
src/__pycache__/conftest.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (1.41 kB)

src/__pycache__/data_loader.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/data_loader_test.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (4.35 kB)

src/__pycache__/router_concept.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/router_dataset.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/router_signal.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/schema.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/schema_test.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (9.4 kB)

src/__pycache__/server.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/server_concept_test.cpython-39-pytest-7.3.1.pyc
CHANGED: binary files differ

src/__pycache__/server_concept_test.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (19.2 kB)

src/__pycache__/server_test.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (12.5 kB)

src/__pycache__/tasks.cpython-39.pyc
CHANGED: binary files differ

src/__pycache__/test_utils.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (1.1 kB)

src/concepts/__pycache__/concept.cpython-39.pyc
CHANGED: binary files differ

src/concepts/__pycache__/concept_test.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (3.74 kB)

src/concepts/__pycache__/db_concept.cpython-39.pyc
CHANGED: binary files differ

src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.3.1.pyc
CHANGED: binary files differ

src/concepts/__pycache__/db_concept_test.cpython-39-pytest-7.4.0.pyc
ADDED: binary file (27.7 kB)
src/concepts/concept.py
CHANGED
(Removed lines that were cut off in this view are shown truncated, exactly as captured.)
@@ -1,23 +1,37 @@
 """Defines the concept and the concept models."""
+import dataclasses
 import random
-from
+from enum import Enum
+from typing import Callable, Iterable, Literal, Optional, Union
 
 import numpy as np
+from joblib import Parallel, delayed
 from pydantic import BaseModel, validator
+from scipy.interpolate import interp1d
+from sklearn.base import BaseEstimator, clone
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import precision_recall_curve, roc_auc_score
+from sklearn.model_selection import KFold
 
 from ..db_manager import get_dataset
 from ..embeddings.embedding import get_embed_fn
 from ..schema import Path, RichData, SignalInputType, normalize_path
-from ..signals.signal import TextEmbeddingSignal, get_signal_cls
+from ..signals.signal import EMBEDDING_KEY, TextEmbeddingSignal, get_signal_cls
 from ..utils import DebugTimer
 
 LOCAL_CONCEPT_NAMESPACE = 'local'
 
 # Number of randomly sampled negative examples to use for training. This is used to obtain a more
 # balanced model that works with a specific dataset.
-DEFAULT_NUM_NEG_EXAMPLES =
+DEFAULT_NUM_NEG_EXAMPLES = 100
+
+# The maximum number of cross-validation models to train.
+MAX_NUM_CROSS_VAL_MODELS = 15
+# The β weight to use for the F-beta score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html
+# β = 0.5 means we value precision 2x as much as recall.
+# β = 2 means we value recall 2x as much as precision.
+F_BETA_WEIGHT = 0.5
 
 
 class ConceptColumnInfo(BaseModel):
@@ -29,6 +43,13 @@ class ConceptColumnInfo(BaseModel):
   # Path holding the text to use for negative examples.
   path: Path
 
+  @validator('path')
+  def _path_points_to_text_field(cls, path: Path) -> Path:
+    if path[-1] == EMBEDDING_KEY:
+      raise ValueError(
+        f'The path should point to the text field, not its embedding field. Provided path: {path}')
+    return path
+
   num_negative_examples = DEFAULT_NUM_NEG_EXAMPLES
 
 
@@ -71,7 +92,7 @@ class Example(ExampleIn):
 class Concept(BaseModel):
   """A concept is a collection of examples."""
   # The namespace of the concept.
-  namespace: str
+  namespace: str
   # The name of the concept.
   concept_name: str
   # The type of the data format that this concept represents.
@@ -79,6 +100,8 @@ class Concept(BaseModel):
   data: dict[str, Example]
   version: int = 0
 
+  description: Optional[str] = None
+
   def drafts(self) -> list[DraftId]:
     """Gets all the drafts for the concept."""
     drafts: set[DraftId] = set([DRAFT_MAIN])  # Always return the main draft.
@@ -88,39 +111,141 @@ class Concept(BaseModel):
     return list(sorted(drafts))
 
 
-class
-  """
+class OverallScore(str, Enum):
+  """Enum holding the overall score."""
+  NOT_GOOD = 'not_good'
+  OK = 'ok'
+  GOOD = 'good'
+  VERY_GOOD = 'very_good'
+  GREAT = 'great'
+
+
+def _get_overall_score(f1_score: float) -> OverallScore:
+  if f1_score < 0.5:
+    return OverallScore.NOT_GOOD
+  if f1_score < 0.8:
+    return OverallScore.OK
+  if f1_score < 0.9:
+    return OverallScore.GOOD
+  if f1_score < 0.95:
+    return OverallScore.VERY_GOOD
+  return OverallScore.GREAT
+
+
+class ConceptMetrics(BaseModel):
+  """Metrics for a concept."""
+  # The average F1 score for the concept computed using cross validation.
+  f1: float
+  precision: float
+  recall: float
+  roc_auc: float
+  overall: OverallScore
+
+
+@dataclasses.dataclass
+class LogisticEmbeddingModel:
+  """A model that uses logistic regression with embeddings."""
 
-  class Config:
-    arbitrary_types_allowed = True
-    underscore_attrs_are_private = True
+  _metrics: Optional[ConceptMetrics] = None
+  _threshold: float = 0.5
 
-
-
-
-
-
+  def __post_init__(self) -> None:
+    # See `notebooks/Toxicity.ipynb` for an example of training a concept model.
+    self._model = LogisticRegression(
+      class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=1_000, n_jobs=-1)
 
   def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
     """Get the scores for the provided embeddings."""
     try:
-
+      y_probs = self._model.predict_proba(embeddings)[:, 1]
+      # Map [0, threshold, 1] to [0, 0.5, 1].
+      interpolate_fn = interp1d([0, self._threshold, 1], [0, 0.4999, 1])
+      return interpolate_fn(y_probs)
     except NotFittedError:
       return np.random.rand(len(embeddings))
 
-  def
+  def _setup_training(
+      self, X_train: np.ndarray, y_train: list[bool],
+      implicit_negatives: Optional[np.ndarray]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    num_pos_labels = len([y for y in y_train if y])
+    num_neg_labels = len([y for y in y_train if not y])
+    sample_weights = [(1.0 / num_pos_labels if y else 1.0 / num_neg_labels) for y in y_train]
+
+    if implicit_negatives is not None:
+      num_implicit_labels = len(implicit_negatives)
+      implicit_labels = [False] * num_implicit_labels
+      X_train = np.concatenate([implicit_negatives, X_train])
+      y_train = np.concatenate([implicit_labels, y_train])
+      sample_weights = [1.0 / num_implicit_labels] * num_implicit_labels + sample_weights
+
+    # Normalize sample weights to sum to the number of training examples.
+    weights = np.array(sample_weights)
+    weights *= (X_train.shape[0] / np.sum(weights))
+    return X_train, np.array(y_train), weights
+
+  def fit(self, embeddings: np.ndarray, labels: list[bool],
+          implicit_negatives: Optional[np.ndarray]) -> None:
     """Fit the model to the provided embeddings and labels."""
-
+    label_set = set(labels)
+    if implicit_negatives is not None:
+      label_set.add(False)
+    if len(label_set) < 2:
       return
     if len(labels) != len(embeddings):
      raise ValueError(
        f'Length of embeddings ({len(embeddings)}) must match length of labels ({len(labels)})')
-
-
-
-
-
+    X_train, y_train, sample_weights = self._setup_training(embeddings, labels, implicit_negatives)
+    self._model.fit(X_train, y_train, sample_weights)
+    self._metrics, self._threshold = self._compute_metrics(embeddings, labels, implicit_negatives)
+
+  def _compute_metrics(
+      self, embeddings: np.ndarray, labels: list[bool],
+      implicit_negatives: Optional[np.ndarray]) -> tuple[Optional[ConceptMetrics], float]:
+    """Return the concept metrics."""
+    labels = np.array(labels)
+    n_splits = min(len(labels), MAX_NUM_CROSS_VAL_MODELS)
+    fold = KFold(n_splits, shuffle=True, random_state=42)
+
+    def _fit_and_score(model: BaseEstimator, X_train: np.ndarray, y_train: np.ndarray,
+                       sample_weights: np.ndarray, X_test: np.ndarray,
+                       y_test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+      if len(set(y_train)) < 2:
+        return np.array([]), np.array([])
+      model.fit(X_train, y_train, sample_weights)
+      y_pred = model.predict_proba(X_test)[:, 1]
+      return y_test, y_pred
+
+    # Compute the metrics for each validation fold in parallel.
+    jobs: list[Callable] = []
+    for (train_index, test_index) in fold.split(embeddings):
+      X_train, y_train = embeddings[train_index], labels[train_index]
+      X_train, y_train, sample_weights = self._setup_training(X_train, y_train, implicit_negatives)
+      X_test, y_test = embeddings[test_index], labels[test_index]
+      model = clone(self._model)
+      jobs.append(delayed(_fit_and_score)(model, X_train, y_train, sample_weights, X_test, y_test))
+    results = Parallel(n_jobs=-1)(jobs)
+
+    y_test = np.concatenate([y_test for y_test, _ in results], axis=0)
+    y_pred = np.concatenate([y_pred for _, y_pred in results], axis=0)
+    if len(set(y_test)) < 2:
+      return None, 0.5
+    roc_auc_val = roc_auc_score(y_test, y_pred)
+    precision, recall, thresholds = precision_recall_curve(y_test, y_pred)
+    numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
+    denom = (F_BETA_WEIGHT**2 * precision) + recall
+    f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
+    max_f1: float = np.max(f1_scores)
+    max_f1_index = np.argmax(f1_scores)
+    max_f1_thresh: float = thresholds[max_f1_index]
+    max_f1_prec: float = precision[max_f1_index]
+    max_f1_recall: float = recall[max_f1_index]
+    metrics = ConceptMetrics(
+      f1=max_f1,
+      precision=max_f1_prec,
+      recall=max_f1_recall,
+      roc_auc=roc_auc_val,
+      overall=_get_overall_score(max_f1))
+    return metrics, max_f1_thresh
 
 
 def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
@@ -136,7 +261,7 @@ def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
     raise ValueError(
       f'Draft {draft} not found in concept. Found drafts: {list(draft_examples.keys())}')
 
-  # Map the text of the draft to its id so we can
+  # Map the text of the draft to its id so we can dedupe with main.
   draft_text_ids = {example.text: id for id, example in draft_examples[draft].items()}
 
   # Write each of examples from main to the draft examples only if the text does not appear in the
@@ -148,7 +273,8 @@ def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
   return draft_examples[draft]
 
 
-class ConceptModel(BaseModel):
+@dataclasses.dataclass
+class ConceptModel:
   """A concept model. Stores all concept model drafts and manages syncing."""
   # The concept that this model is for.
   namespace: str
@@ -158,20 +284,27 @@ class ConceptModel(BaseModel):
   embedding_name: str
   version: int = -1
 
-
+  column_info: Optional[ConceptColumnInfo] = None
+
+  # The following fields are excluded from JSON serialization, but still pickle-able.
   # Maps a concept id to the embeddings.
-  _embeddings: dict[str, np.ndarray] =
-  _logistic_models: dict[DraftId, LogisticEmbeddingModel] =
+  _embeddings: dict[str, np.ndarray] = dataclasses.field(default_factory=dict)
+  _logistic_models: dict[DraftId, LogisticEmbeddingModel] = dataclasses.field(default_factory=dict)
   _negative_vectors: Optional[np.ndarray] = None
 
-
-
-
+  def get_metrics(self, concept: Concept) -> Optional[ConceptMetrics]:
+    """Return the metrics for this model."""
+    return self._get_logistic_model(DRAFT_MAIN)._metrics
+
+  def __post_init__(self) -> None:
+    if self.column_info:
+      self.column_info.path = normalize_path(self.column_info.path)
+      self._calibrate_on_dataset(self.column_info)
 
-  def
+  def _calibrate_on_dataset(self, column_info: ConceptColumnInfo) -> None:
     """Calibrate the model on the embeddings in the provided vector store."""
     db = get_dataset(column_info.namespace, column_info.name)
-    vector_store = db.get_vector_store(normalize_path(column_info.path))
+    vector_store = db.get_vector_store(self.embedding_name, normalize_path(column_info.path))
     keys = vector_store.keys()
     num_samples = min(column_info.num_negative_examples, len(keys))
     sample_keys = random.sample(keys, num_samples)
@@ -199,11 +332,7 @@ class ConceptModel(BaseModel):
   def _get_logistic_model(self, draft: DraftId) -> LogisticEmbeddingModel:
     """Get the logistic model for the provided draft."""
     if draft not in self._logistic_models:
-      self._logistic_models[draft] = LogisticEmbeddingModel(
-        namespace=self.namespace,
-        concept_name=self.concept_name,
-        embedding_name=self.embedding_name,
-        version=-1)
+      self._logistic_models[draft] = LogisticEmbeddingModel()
     return self._logistic_models[draft]
 
   def sync(self, concept: Concept) -> bool:
@@ -222,21 +351,9 @@ class ConceptModel(BaseModel):
     examples = draft_examples(concept, draft)
     embeddings = np.array([self._embeddings[id] for id in examples.keys()])
     labels = [example.label for example in examples.values()]
-    num_pos_labels = len([x for x in labels if x])
-    num_neg_labels = len([x for x in labels if not x])
-    sample_weights = [(1.0 / num_pos_labels if x else 1.0 / num_neg_labels) for x in labels]
-    if self._negative_vectors is not None:
-      num_implicit_labels = len(self._negative_vectors)
-      embeddings = np.concatenate([self._negative_vectors, embeddings])
-      labels = [False] * num_implicit_labels + labels
-      sample_weights = [1.0 / num_implicit_labels] * num_implicit_labels + sample_weights
-
     model = self._get_logistic_model(draft)
     with DebugTimer(f'Fitting model for "{concept_path}"'):
-      model.fit(embeddings, labels,
-
-      # Synchronize the model version with the concept version.
-      model.version = concept.version
+      model.fit(embeddings, labels, self._negative_vectors)
 
     # Synchronize the model version with the concept version.
     self.version = concept.version
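The heart of this change is `_compute_metrics`: it cross-validates the logistic model, picks the decision threshold that maximizes the F-beta score (with β = 0.5, precision counts double), and `score_embeddings` then remaps probabilities through `interp1d` so the chosen threshold lands at 0.5. A self-contained sketch of that thresholding step; only the formula and the β value come from the diff, the labels and probabilities below are invented:

import numpy as np
from scipy.interpolate import interp1d
from sklearn.metrics import precision_recall_curve

F_BETA_WEIGHT = 0.5  # Same beta as the diff: precision is valued 2x as much as recall.

# Invented validation labels and predicted probabilities, for illustration only.
y_true = np.array([0, 0, 1, 1, 0, 1, 1, 0])
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9, 0.65, 0.55])

precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
denom = (F_BETA_WEIGHT**2 * precision) + recall
f_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))

best = int(np.argmax(f_scores))
threshold = thresholds[best]  # Operating point with the highest F-0.5 score.

# Remap [0, threshold, 1] -> [0, ~0.5, 1]: a score >= 0.5 now means "above threshold".
calibrate = interp1d([0, threshold, 1], [0, 0.4999, 1])
print(threshold, calibrate(y_prob))

Note the midpoint maps to 0.4999 rather than 0.5, presumably so that a score sitting exactly on the threshold stays just below the positive side.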
src/concepts/db_concept.py
CHANGED
(Removed lines that were cut off in this view are shown truncated, exactly as captured.)
@@ -2,15 +2,18 @@
 
 import abc
 import glob
+import json
 import os
 import pickle
 import shutil
 
 # NOTE: We have to import the module for uuid so it can be mocked.
 import uuid
+from pathlib import Path
 from typing import List, Optional, Union, cast
 
 from pydantic import BaseModel
+from pyparsing import Any
 from typing_extensions import override
 
 from ..config import data_path
@@ -27,6 +30,7 @@ from .concept import (
   ExampleIn,
 )
 
+CONCEPTS_DIR = 'concept'
 DATASET_CONCEPTS_DIR = '.concepts'
 CONCEPT_JSON_FILENAME = 'concept.json'
@@ -65,8 +69,19 @@ class ConceptDB(abc.ABC):
     pass
 
   @abc.abstractmethod
-  def create(self,
-
+  def create(self,
+             namespace: str,
+             name: str,
+             type: SignalInputType,
+             description: Optional[str] = None) -> Concept:
+    """Create a concept.
+
+    Args:
+      namespace: The namespace of the concept.
+      name: The name of the concept.
+      type: The input type of the concept.
+      description: The description of the concept.
+    """
     pass
 
   @abc.abstractmethod
@@ -115,7 +130,7 @@ class ConceptModelDB(abc.ABC):
     pass
 
   @abc.abstractmethod
-  def _save(self, model: ConceptModel
+  def _save(self, model: ConceptModel) -> None:
     """Save the concept model."""
     pass
 
@@ -126,13 +141,14 @@ class ConceptModelDB(abc.ABC):
       raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
     return concept.version == model.version
 
-  def sync(self, model: ConceptModel
+  def sync(self, model: ConceptModel) -> bool:
     """Sync the concept model. Returns true if the model was updated."""
     concept = self._concept_db.get(model.namespace, model.concept_name)
     if not concept:
       raise ValueError(f'Concept "{model.namespace}/{model.concept_name}" does not exist.')
     model_updated = model.sync(concept)
-
+    if model_updated:
+      self._save(model)
     return model_updated
 
   @abc.abstractmethod
@@ -149,6 +165,16 @@ class ConceptModelDB(abc.ABC):
     """Remove all the models associated with a concept."""
     pass
 
+  @abc.abstractmethod
+  def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
+    """List all the models associated with a concept."""
+    pass
+
+  @abc.abstractmethod
+  def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
+    """Get the dataset columns where this concept was applied to."""
+    pass
+
 
 class DiskConceptModelDB(ConceptModelDB):
   """Interface for the concept model database."""
@@ -191,10 +217,10 @@ class DiskConceptModelDB(ConceptModelDB):
     with open_file(concept_model_path, 'rb') as f:
       return pickle.load(f)
 
-  def _save(self, model: ConceptModel
+  def _save(self, model: ConceptModel) -> None:
     """Save the concept model."""
     concept_model_path = _concept_model_path(model.namespace, model.concept_name,
-                                             model.embedding_name, column_info)
+                                             model.embedding_name, model.column_info)
     with open_file(concept_model_path, 'wb') as f:
       pickle.dump(model, f)
@@ -224,10 +250,39 @@ class DiskConceptModelDB(ConceptModelDB):
     for dir in dirs:
       shutil.rmtree(dir, ignore_errors=True)
 
+  @override
+  def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
+    """List all the models associated with a concept."""
+    model_files = glob.iglob(os.path.join(_concept_output_dir(namespace, concept_name), '*.pkl'))
+    models: list[ConceptModel] = []
+    for model_file in model_files:
+      embedding_name = os.path.basename(model_file)[:-len('.pkl')]
+      model = self.get(namespace, concept_name, embedding_name)
+      if model:
+        models.append(model)
+    return models
+
+  @override
+  def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
+    datasets_path = os.path.join(data_path(), DATASETS_DIR_NAME)
+    # Skip if 'datasets' doesn't exist.
+    if not os.path.isdir(datasets_path):
+      return []
+
+    dirs = glob.iglob(
+      os.path.join(datasets_path, '**', DATASET_CONCEPTS_DIR, namespace, concept_name),
+      recursive=True)
+    result: list[ConceptColumnInfo] = []
+    for dir in dirs:
+      dir = os.path.relpath(dir, datasets_path)
+      dataset_namespace, dataset_name, *path, _, _, _ = Path(dir).parts
+      result.append(ConceptColumnInfo(namespace=dataset_namespace, name=dataset_name, path=path))
+    return result
+
 
 def _concept_output_dir(namespace: str, name: str) -> str:
   """Return the output directory for a given concept."""
-  return os.path.join(data_path(),
+  return os.path.join(data_path(), CONCEPTS_DIR, namespace, name)
 
 
 def _concept_json_path(namespace: str, name: str) -> str:
@@ -246,7 +301,7 @@ def _concept_model_path(namespace: str,
   path_without_wildcards = (p for p in path_tuple if p != PATH_WILDCARD)
   path_dir = os.path.join(dataset_dir, *path_without_wildcards)
   return os.path.join(path_dir, DATASET_CONCEPTS_DIR, namespace, concept_name,
-                      f'{embedding_name}.pkl')
+                      f'{embedding_name}-neg-{column_info.num_negative_examples}.pkl')
 
 
 class DiskConceptDB(ConceptDB):
@@ -280,16 +335,29 @@ class DiskConceptDB(ConceptDB):
       return None
 
     with open_file(concept_json_path) as f:
-
+      obj: dict[str, Any] = json.load(f)
+      if 'namespace' not in obj:
+        obj['namespace'] = namespace
+      return Concept.parse_obj(obj)
 
   @override
-  def create(self,
-
+  def create(self,
+             namespace: str,
+             name: str,
+             type: SignalInputType,
+             description: Optional[str] = None) -> Concept:
    """Create a concept."""
    concept_json_path = _concept_json_path(namespace, name)
    if file_exists(concept_json_path):
      raise ValueError(f'Concept with namespace "{namespace}" and name "{name}" already exists.')
 
-    concept = Concept(
+    concept = Concept(
+      namespace=namespace,
+      concept_name=name,
+      type=type,
+      data={},
+      version=0,
+      description=description)
    self._save(concept)
    return concept
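The new `get_column_infos` recovers the dataset namespace, dataset name, and column path from the on-disk layout by unpacking `Path(dir).parts` and discarding the trailing `.concepts/<namespace>/<concept_name>` segments. A small sketch of that unpacking with an invented relative path; the directory names are hypothetical, only the layout comes from the diff:

from pathlib import Path

# Hypothetical layout under the datasets root:
# <dataset_namespace>/<dataset_name>/<column path...>/.concepts/<concept ns>/<concept name>
rel_dir = 'local/movies/text/.concepts/local/toxicity'

# The last three parts are DATASET_CONCEPTS_DIR, the concept namespace and the concept name.
dataset_namespace, dataset_name, *column_path, _, _, _ = Path(rel_dir).parts
print(dataset_namespace, dataset_name, column_path)  # -> local movies ['text']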
src/concepts/db_concept_test.py
CHANGED
(The scrape of this file's diff is cut off partway through; removed lines whose content was lost are shown truncated, exactly as captured, and the view ends where the source page ends.)
@@ -1,7 +1,7 @@
 """Tests for the the database concept."""
 
 from pathlib import Path
-from typing import Generator, Iterable, Type, cast
+from typing import Generator, Iterable, Optional, Type, cast
 
 import numpy as np
 import pytest
@@ -423,7 +423,8 @@ class TestLogisticModel(LogisticEmbeddingModel):
     return np.array([.1])
 
   @override
-  def fit(self, embeddings: np.ndarray, labels: list[bool],
+  def fit(self, embeddings: np.ndarray, labels: list[bool],
+          implicit_negatives: Optional[np.ndarray]) -> None:
     pass
 
@@ -436,20 +437,24 @@ class ConceptModelDBSuite:
     concept_db = concept_db_cls()
     model_db = model_db_cls(concept_db)
     model = _make_test_concept_model(concept_db)
-    model_db.sync(model
+    model_db.sync(model)
     retrieved_model = model_db.get(
       namespace='test', concept_name='test_concept', embedding_name='test_embedding')
     if not retrieved_model:
       retrieved_model = model_db.create(
         namespace='test', concept_name='test_concept', embedding_name='test_embedding')
-    assert retrieved_model == model
+    assert retrieved_model.namespace == model.namespace
+    assert retrieved_model.concept_name == model.concept_name
+    assert retrieved_model.embedding_name == model.embedding_name
+    assert retrieved_model.version == model.version
+    assert retrieved_model.column_info == model.column_info
 
   def test_sync_model(self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB],
                       mocker: MockerFixture) -> None:
 
     concept_db = concept_db_cls()
     model_db = model_db_cls(concept_db)
-    logistic_model = TestLogisticModel(
+    logistic_model = TestLogisticModel()
     score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
     fit_mock = mocker.spy(TestLogisticModel, 'fit')
 
@@ -459,7 +464,7 @@ class ConceptModelDBSuite:
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 0
 
-    model_db.sync(model
+    model_db.sync(model)
 
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
@@ -471,20 +476,20 @@ class ConceptModelDBSuite:
     model_db = model_db_cls(concept_db)
     score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
     fit_mock = mocker.spy(TestLogisticModel, 'fit')
-    logistic_model = TestLogisticModel(
     model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
-    model_db.sync(model
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
 
     (called_model, called_embeddings, called_labels,
-
     assert called_model == logistic_model
     np.testing.assert_array_equal(
       called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
     assert called_labels == [False, True]
-    assert
 
     # Edit the concept.
     concept_db.edit('test', 'test_concept',
@@ -495,13 +500,13 @@ class ConceptModelDBSuite:
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
 
-    model_db.sync(model
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 2
     # Fit is called again with new points on main only.
     (called_model, called_embeddings, called_labels,
-
     assert called_model == logistic_model
     np.testing.assert_array_equal(
       called_embeddings,
@@ -510,7 +515,7 @@ class ConceptModelDBSuite:
       EMBEDDING_MAP['a new data point']
     ]))
     assert called_labels == [False, True, False]
-    assert
 
   def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB],
                                    model_db_cls: Type[ConceptModelDB],
@@ -519,14 +524,14 @@ class ConceptModelDBSuite:
     model_db = model_db_cls(concept_db)
     score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
     fit_mock = mocker.spy(TestLogisticModel, 'fit')
-
-    draft_model = TestLogisticModel(
     model = _make_test_concept_model(
       concept_db, logistic_models={
-        DRAFT_MAIN:
        'test_draft': draft_model
      })
-    model_db.sync(model
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
@@ -547,15 +552,16 @@ class ConceptModelDBSuite:
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 1
 
-    model_db.sync(model
     assert model_db.in_sync(model) is True
     assert score_embeddings_mock.call_count == 0
     assert fit_mock.call_count == 3  # Fit is called on both the draft, and main.
 
     # Fit is called again with the same points.
-    ((called_model, called_embeddings, called_labels,
-     (called_draft_model, called_draft_embeddings, called_draft_labels,
-
 
     # The draft model is called with the data from main, and the data from draft.
     assert called_draft_model == draft_model
@@ -572,21 +578,21 @@ class ConceptModelDBSuite:
       False,
       False
     ]
-    assert
 
     # The main model was fit without the data from the draft.
-    assert called_model ==
     np.testing.assert_array_equal(
       called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
     assert called_labels == [False, True]
-    assert
 
   def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB],
                                       model_db_cls: Type[ConceptModelDB]) -> None:
     concept_db = concept_db_cls()
     model_db = model_db_cls(concept_db)
     model = _make_test_concept_model(concept_db)
-    model_db.sync(model
 
     # Edit the concept.
     concept_db.edit('test', 'test_concept',
@@ -596,5 +602,5 @@ class ConceptModelDBSuite:
     assert model_db.in_sync(model) is False
 
     with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
-      model_db.sync(model
-      model_db.sync(model
477 |
score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
|
478 |
fit_mock = mocker.spy(TestLogisticModel, 'fit')
|
479 |
+
logistic_model = TestLogisticModel()
|
480 |
model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
|
481 |
+
model_db.sync(model)
|
482 |
assert model_db.in_sync(model) is True
|
483 |
assert score_embeddings_mock.call_count == 0
|
484 |
assert fit_mock.call_count == 1
|
485 |
|
486 |
(called_model, called_embeddings, called_labels,
|
487 |
+
called_implicit_negatives) = fit_mock.call_args_list[-1].args
|
488 |
assert called_model == logistic_model
|
489 |
np.testing.assert_array_equal(
|
490 |
called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
|
491 |
assert called_labels == [False, True]
|
492 |
+
assert called_implicit_negatives is None
|
493 |
|
494 |
# Edit the concept.
|
495 |
concept_db.edit('test', 'test_concept',
|
|
|
500 |
assert score_embeddings_mock.call_count == 0
|
501 |
assert fit_mock.call_count == 1
|
502 |
|
503 |
+
model_db.sync(model)
|
504 |
assert model_db.in_sync(model) is True
|
505 |
assert score_embeddings_mock.call_count == 0
|
506 |
assert fit_mock.call_count == 2
|
507 |
# Fit is called again with new points on main only.
|
508 |
(called_model, called_embeddings, called_labels,
|
509 |
+
called_implicit_negatives) = fit_mock.call_args_list[-1].args
|
510 |
assert called_model == logistic_model
|
511 |
np.testing.assert_array_equal(
|
512 |
called_embeddings,
|
|
|
515 |
EMBEDDING_MAP['a new data point']
|
516 |
]))
|
517 |
assert called_labels == [False, True, False]
|
518 |
+
assert called_implicit_negatives is None
|
519 |
|
520 |
def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB],
|
521 |
model_db_cls: Type[ConceptModelDB],
|
|
|
524 |
model_db = model_db_cls(concept_db)
|
525 |
score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
|
526 |
fit_mock = mocker.spy(TestLogisticModel, 'fit')
|
527 |
+
main_model = TestLogisticModel()
|
528 |
+
draft_model = TestLogisticModel()
|
529 |
model = _make_test_concept_model(
|
530 |
concept_db, logistic_models={
|
531 |
+
DRAFT_MAIN: main_model,
|
532 |
'test_draft': draft_model
|
533 |
})
|
534 |
+
model_db.sync(model)
|
535 |
assert model_db.in_sync(model) is True
|
536 |
assert score_embeddings_mock.call_count == 0
|
537 |
assert fit_mock.call_count == 1
|
|
|
552 |
assert score_embeddings_mock.call_count == 0
|
553 |
assert fit_mock.call_count == 1
|
554 |
|
555 |
+
model_db.sync(model)
|
556 |
assert model_db.in_sync(model) is True
|
557 |
assert score_embeddings_mock.call_count == 0
|
558 |
assert fit_mock.call_count == 3 # Fit is called on both the draft, and main.
|
559 |
|
560 |
# Fit is called again with the same points.
|
561 |
+
((called_model, called_embeddings, called_labels, called_implicit_negatives),
|
562 |
+
(called_draft_model, called_draft_embeddings, called_draft_labels,
|
563 |
+
called_draft_implicit_negatives)) = (
|
564 |
+
c.args for c in fit_mock.call_args_list[-2:])
|
565 |
|
566 |
# The draft model is called with the data from main, and the data from draft.
|
567 |
assert called_draft_model == draft_model
|
|
|
578 |
False,
|
579 |
False
|
580 |
]
|
581 |
+
assert called_draft_implicit_negatives is None
|
582 |
|
583 |
# The main model was fit without the data from the draft.
|
584 |
+
assert called_model == main_model
|
585 |
np.testing.assert_array_equal(
|
586 |
called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
|
587 |
assert called_labels == [False, True]
|
588 |
+
assert called_implicit_negatives is None
|
589 |
|
590 |
def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB],
|
591 |
model_db_cls: Type[ConceptModelDB]) -> None:
|
592 |
concept_db = concept_db_cls()
|
593 |
model_db = model_db_cls(concept_db)
|
594 |
model = _make_test_concept_model(concept_db)
|
595 |
+
model_db.sync(model)
|
596 |
|
597 |
# Edit the concept.
|
598 |
concept_db.edit('test', 'test_concept',
|
|
|
602 |
assert model_db.in_sync(model) is False
|
603 |
|
604 |
with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
|
605 |
+
model_db.sync(model)
|
606 |
+
model_db.sync(model)
|
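A quick aside on the pattern these tests lean on: `mocker.spy` wraps the unbound class method, so each recorded call's positional args begin with `self`, which is why the tuple unpacks above start with `called_model`. A minimal, self-contained sketch of that behavior (the `Model` class and test are hypothetical, not part of this diff):

import numpy as np
from pytest_mock import MockerFixture


class Model:

  def fit(self, embeddings: np.ndarray, labels: list[bool], implicit_negatives) -> None:
    pass


def test_spy_records_self_and_args(mocker: MockerFixture) -> None:
  # Spying on the class (not an instance) records every call on any instance.
  fit_spy = mocker.spy(Model, 'fit')
  model = Model()
  model.fit(np.zeros((2, 3)), [False, True], None)

  # call_args_list[-1].args == (self, embeddings, labels, implicit_negatives).
  called_self, embeddings, labels, negatives = fit_spy.call_args_list[-1].args
  assert called_self is model
  assert embeddings.shape == (2, 3)
  assert labels == [False, True]
  assert negatives is None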
src/data/__pycache__/dataset.cpython-39.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset.cpython-39.pyc and b/src/data/__pycache__/dataset.cpython-39.pyc differ

src/data/__pycache__/dataset_compute_signal_chain_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (9.78 kB).

src/data/__pycache__/dataset_compute_signal_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (20.7 kB).

src/data/__pycache__/dataset_duckdb.cpython-39.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset_duckdb.cpython-39.pyc and b/src/data/__pycache__/dataset_duckdb.cpython-39.pyc differ

src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.3.1.pyc differ

src/data/__pycache__/dataset_select_groups_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (8.86 kB).

src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.3.1.pyc differ

src/data/__pycache__/dataset_select_rows_filter_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (6.8 kB).

src/data/__pycache__/dataset_select_rows_schema_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (17.3 kB).

src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.3.1.pyc differ

src/data/__pycache__/dataset_select_rows_search_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (11.8 kB).

src/data/__pycache__/dataset_select_rows_sort_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (20.8 kB).

src/data/__pycache__/dataset_select_rows_udf_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (16 kB).

src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc and b/src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.3.1.pyc differ

src/data/__pycache__/dataset_stats_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (5.66 kB).

src/data/__pycache__/dataset_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (21.8 kB).

src/data/__pycache__/dataset_utils.cpython-39.pyc
CHANGED
Binary files a/src/data/__pycache__/dataset_utils.cpython-39.pyc and b/src/data/__pycache__/dataset_utils.cpython-39.pyc differ

src/data/__pycache__/dataset_utils_test.cpython-39-pytest-7.4.0.pyc
ADDED
Binary file (4.22 kB).

src/data/__pycache__/duckdb_utils.cpython-39.pyc
CHANGED
Binary files a/src/data/__pycache__/duckdb_utils.cpython-39.pyc and b/src/data/__pycache__/duckdb_utils.cpython-39.pyc differ
src/data/dataset.py
CHANGED
@@ -255,7 +255,7 @@ class Dataset(abc.ABC):
     pass

   @abc.abstractmethod
-  def get_vector_store(self, path: PathTuple) -> VectorStore:
+  def get_vector_store(self, embedding: str, path: PathTuple) -> VectorStore:
     # TODO: Instead of this, allow selecting vectors via select_rows.
     """Get the vector store for a column."""
     pass
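For context on the signature change: implementations now receive the embedding name alongside the path, so vectors computed for the same column by different embeddings can be told apart. A hedged sketch of what a minimal subclass might look like; `InMemoryVectorStore` and the dict-based lookup are illustrative assumptions, not the project's real classes:

from typing import Tuple

PathTuple = Tuple[str, ...]


class InMemoryVectorStore:
  """Stand-in for the real VectorStore interface."""


class MyDataset:

  def __init__(self) -> None:
    self._stores: dict[tuple[str, PathTuple], InMemoryVectorStore] = {}

  def get_vector_store(self, embedding: str, path: PathTuple) -> InMemoryVectorStore:
    # Keying by (embedding, path) mirrors why the extra argument exists: one
    # column can have vector stores from several embeddings.
    key = (embedding, path)
    if key not in self._stores:
      self._stores[key] = InMemoryVectorStore()
    return self._stores[key]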
src/data/dataset_duckdb.py
CHANGED
@@ -241,15 +241,23 @@ class DatasetDuckDB(Dataset):
     raise NotImplementedError('count is not yet implemented for DuckDB.')

   @override
-  def get_vector_store(self, path: PathTuple) -> VectorStore:
+  def get_vector_store(self, embedding: str, path: PathTuple) -> VectorStore:
     # Refresh the manifest to make sure we have the latest signal manifests.
     self.manifest()

+    if path[-1] != EMBEDDING_KEY:
+      path = (*path, embedding, PATH_WILDCARD, EMBEDDING_KEY)
+
     if path not in self._col_vector_stores:
+      manifests = [
+        m for m in self._signal_manifests
+        if schema_contains_path(m.data_schema, path) and m.embedding_filename_prefix
+      ]
+      if not manifests:
         raise ValueError(f'No embedding found for path {path}.')
+      if len(manifests) > 1:
+        raise ValueError(f'Multiple embeddings found for path {path}. Got: {manifests}')
+      manifest = manifests[0]
       if not manifest.embedding_filename_prefix:
         raise ValueError(f'Signal manifest for path {path} is not an embedding. '
                          f'Got signal manifest: {manifest}')

@@ -273,7 +281,7 @@ class DatasetDuckDB(Dataset):
                         manifest: DatasetManifest,
                         compute_dependencies: Optional[bool] = False,
                         task_step_id: Optional[TaskStepId] = None) -> tuple[PathTuple, Optional[TaskStepId]]:
-    """Run all the signals
+    """Run all the signals dependencies required to run this signal.

     Args:
       signal: The signal to prepare.

@@ -560,7 +568,8 @@ class DatasetDuckDB(Dataset):
     if is_ordinal(leaf.dtype):
       min_max_query = f"""
        SELECT MIN(val) AS minVal, MAX(val) AS maxVal
-        FROM (SELECT {inner_select} as val FROM t)
+        FROM (SELECT {inner_select} as val FROM t)
+        WHERE NOT isnan(val)
      """
       row = self._query(min_max_query)[0]
       result.min_val, result.max_val = row

@@ -590,7 +599,9 @@ class DatasetDuckDB(Dataset):
     named_bins = _normalize_bins(bins or leaf.bins)
     stats = self.stats(leaf_path)

+    leaf_is_float = is_float(leaf.dtype)
+    leaf_is_integer = is_integer(leaf.dtype)
+    if leaf_is_float or leaf_is_integer:
       if named_bins is None:
         # Auto-bin.
         named_bins = _auto_bins(stats, NUM_AUTO_BINS)

@@ -606,11 +617,14 @@ class DatasetDuckDB(Dataset):
       bin_index_col = 'col0'
       bin_min_col = 'col1'
       bin_max_col = 'col2'
+      is_nan_filter = f'NOT isnan({inner_val}) AND' if leaf_is_float else ''
+
+      # We cast the field to `double` so binning works for both `float` and `int` fields.
       outer_select = f"""(
         SELECT {bin_index_col} FROM (
           VALUES {', '.join(sql_bounds)}
-        ) WHERE {
+        ) WHERE {is_nan_filter}
+        {inner_val}::DOUBLE >= {bin_min_col} AND {inner_val}::DOUBLE < {bin_max_col}
       )"""
     else:
       if stats.approx_count_distinct >= dataset.TOO_MANY_DISTINCT:

@@ -625,6 +639,7 @@ class DatasetDuckDB(Dataset):
     filters, _ = self._normalize_filters(filters, col_aliases={}, udf_aliases={}, manifest=manifest)
     filter_queries = self._create_where(manifest, filters, searches=[])
+
     where_query = ''
     if filter_queries:
       where_query = f"WHERE {' AND '.join(filter_queries)}"

@@ -756,8 +771,9 @@ class DatasetDuckDB(Dataset):
     for udf_col in udf_columns:
       if isinstance(udf_col.signal_udf, ConceptScoreSignal):
         # Set dataset information on the signal.
+        source_path = udf_col.path if udf_col.path[-1] != EMBEDDING_KEY else udf_col.path[:-3]
         udf_col.signal_udf.set_column_info(
-          ConceptColumnInfo(namespace=self.namespace, name=self.dataset_name, path=
+          ConceptColumnInfo(namespace=self.namespace, name=self.dataset_name, path=source_path))

     # Decide on the exact sorting order.
     sort_results = self._merge_sorts(search_udfs, sort_by, sort_order)

@@ -791,19 +807,19 @@ class DatasetDuckDB(Dataset):
     if topk_udf_col:
       key_prefixes: Optional[list[VectorKey]] = None
       if where_query:
-        # If there are filters, we need to send UUIDs to the
+        # If there are filters, we need to send UUIDs to the top k query.
         df = con.execute(f'SELECT {UUID_COLUMN} FROM t {where_query}').df()
         total_num_rows = len(df)
        key_prefixes = df[UUID_COLUMN]

+      topk_signal = cast(TextEmbeddingModelSignal, topk_udf_col.signal_udf)
       # The input is an embedding.
-      vector_store = self.get_vector_store(topk_udf_col.path)
+      vector_store = self.get_vector_store(topk_signal.embedding, topk_udf_col.path)
       k = (limit or 0) + (offset or 0)
-      topk =
+      topk = topk_signal.vector_compute_topk(k, vector_store, key_prefixes)
       topk_uuids = list(dict.fromkeys([cast(str, key[0]) for key, _ in topk]))

-      # Ignore all the other filters and filter DuckDB results only by the
+      # Ignore all the other filters and filter DuckDB results only by the top k UUIDs.
       uuid_filter = Filter(path=(UUID_COLUMN,), op=ListOp.IN, value=topk_uuids)
       filter_query = self._create_where(manifest, [uuid_filter])[0]
       where_query = f'WHERE {filter_query}'

@@ -923,7 +939,8 @@ class DatasetDuckDB(Dataset):

       if signal.compute_type in [SignalInputType.TEXT_EMBEDDING]:
         # The input is an embedding.
+        embedding_signal = cast(TextEmbeddingModelSignal, signal)
+        vector_store = self.get_vector_store(embedding_signal.embedding, udf_col.path)
         flat_keys = flatten_keys(df[UUID_COLUMN], input)
         signal_out = signal.vector_compute(flat_keys, vector_store)
         # Add progress.

@@ -1273,39 +1290,48 @@ class DatasetDuckDB(Dataset):
     binary_ops = set(BinaryOp)
     unary_ops = set(UnaryOp)
     list_ops = set(ListOp)
+    for f in filters:
+      duckdb_path = self._leaf_path_to_duckdb_path(f.path, manifest.data_schema)
       select_str = _select_sql(duckdb_path, flatten=True, unnest=False)
+      is_array = any(subpath == PATH_WILDCARD for subpath in f.path)
+
+      nan_filter = ''
+      field = manifest.data_schema.get_field(f.path)
+      filter_nans = field.dtype and is_float(field.dtype)

+      if f.op in binary_ops:
+        sql_op = BINARY_OP_TO_SQL[cast(BinaryOp, f.op)]
+        filter_val = cast(FeatureValue, f.value)
         if isinstance(filter_val, str):
           filter_val = f"'{filter_val}'"
         elif isinstance(filter_val, bytes):
           filter_val = _bytes_to_blob_literal(filter_val)
         else:
           filter_val = str(filter_val)
+        if is_array:
+          nan_filter = 'NOT isnan(x) AND' if filter_nans else ''
+          filter_query = (f'len(list_filter({select_str}, '
+                          f'x -> {nan_filter} x {sql_op} {filter_val})) > 0')
+        else:
+          nan_filter = f'NOT isnan({select_str}) AND' if filter_nans else ''
+          filter_query = f'{nan_filter} {select_str} {sql_op} {filter_val}'
+      elif f.op in unary_ops:
+        if f.op == UnaryOp.EXISTS:
           filter_query = f'len({select_str}) > 0' if is_array else f'{select_str} IS NOT NULL'
         else:
+          raise ValueError(f'Unary op: {f.op} is not yet supported')
+      elif f.op in list_ops:
+        if f.op == ListOp.IN:
+          filter_list_val = cast(FeatureListValue, f.value)
           if not isinstance(filter_list_val, list):
             raise ValueError('filter with array value can only use the IN comparison')
           wrapped_filter_val = [f"'{part}'" for part in filter_list_val]
           filter_val = f'({", ".join(wrapped_filter_val)})'
           filter_query = f'{select_str} IN {filter_val}'
         else:
+          raise ValueError(f'List op: {f.op} is not yet supported')
       else:
+        raise ValueError(f'Invalid filter op: {f.op}')
       sql_filter_queries.append(filter_query)
     return sql_filter_queries

@@ -1330,7 +1356,7 @@ class DatasetDuckDB(Dataset):
     return rows

   def _query_df(self, query: str) -> pd.DataFrame:
-    """Execute a query that returns a
+    """Execute a query that returns a data frame."""
     result = self._execute(query)
     df = _replace_nan_with_none(result.df())
     result.close()
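The recurring theme in these hunks is guarding DuckDB aggregations and comparisons against NaN. A runnable sketch of why the guard matters (assumes only the `duckdb` Python package; the table and values are invented): DuckDB orders NaN above every other float, so an unguarded MAX returns NaN and NaN rows satisfy `>=` filters.

import duckdb

con = duckdb.connect()
con.execute("CREATE TABLE t AS SELECT * FROM (VALUES (1.0), (3.0), ('NaN'::DOUBLE)) v(val)")

# NaN sorts above every number in DuckDB, so an unguarded MAX returns NaN.
print(con.execute('SELECT MAX(val) FROM t').fetchone())
print(con.execute('SELECT MIN(val), MAX(val) FROM t WHERE NOT isnan(val)').fetchone())  # (1.0, 3.0)

# The same guard keeps NaN rows out of binary filters.
print(con.execute('SELECT count(*) FROM t WHERE NOT isnan(val) AND val >= 2.0').fetchone())  # (1,)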
src/data/dataset_select_groups_test.py
CHANGED
@@ -167,6 +167,8 @@ def test_named_bins(make_test_data: TestDataMaker) -> None:
     'age': 80
   }, {
     'age': 55
+  }, {
+    'age': float('nan')
   }]
   dataset = make_test_data(items)

@@ -178,7 +180,7 @@ def test_named_bins(make_test_data: TestDataMaker) -> None:
       ('middle-aged', 50, 65),
       ('senior', 65, None),
     ])
-  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1)]
+  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]


 def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
@@ -192,11 +194,13 @@ def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
     'age': 80
   }, {
     'age': 55
+  }, {
+    'age': float('nan')
   }]
   data_schema = schema({
     UUID_COLUMN: 'string',
     'age': field(
+      'float32',
       bins=[
         ('young', None, 20),
         ('adult', 20, 50),
@@ -207,7 +211,7 @@ def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
   dataset = make_test_data(items, data_schema)

   result = dataset.select_groups(leaf_path='age')
-  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1)]
+  assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]


 def test_filters(make_test_data: TestDataMaker) -> None:
@@ -304,10 +308,10 @@ def test_too_many_distinct(make_test_data: TestDataMaker, mocker: MockerFixture)


 def test_auto_bins_for_float(make_test_data: TestDataMaker) -> None:
-  items: list[Item] = [{'feature': float(i)} for i in range(5)]
+  items: list[Item] = [{'feature': float(i)} for i in range(5)] + [{'feature': float('nan')}]
   dataset = make_test_data(items)

   res = dataset.select_groups('feature')
-  assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1)]
+  assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1), (None, 1)]
   assert res.too_many_distinct is False
   assert res.bins
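Why the NaN rows surface as a `(None, 1)` group in these tests: NaN fails every `low <= x < high` bin test, so it falls through to a catch-all bucket. A pure-Python sketch of that binning rule (the `bin_name` helper is invented for illustration; the bin names mirror the tests above):

import math

BINS = [('young', None, 20), ('adult', 20, 50), ('middle-aged', 50, 65), ('senior', 65, None)]


def bin_name(x: float):
  for name, low, high in BINS:
    if (low is None or x >= low) and (high is None or x < high):
      return name
  return None  # NaN (which fails every comparison) lands here.


ages = [20.0, 80.0, 55.0, 35.0, 5.0, float('nan')]
print([bin_name(a) for a in ages])
# ['adult', 'senior', 'middle-aged', 'adult', 'young', None]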
src/data/dataset_select_rows_filter_test.py
CHANGED
@@ -24,6 +24,9 @@ TEST_DATA: list[Item] = [{
   'int': 2,
   'bool': True,
   'float': 1.0
+}, {
+  UUID_COLUMN: '4',
+  'float': float('nan')
 }]


@@ -46,6 +49,91 @@ def test_filter_by_ids(make_test_data: TestDataMaker) -> None:
   assert list(result) == []


+def test_filter_greater(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}]
+
+
+def test_filter_greater_equal(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER_EQUAL, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{
+    UUID_COLUMN: '1',
+    'str': 'a',
+    'int': 1,
+    'bool': False,
+    'float': 3.0
+  }, {
+    UUID_COLUMN: '2',
+    'str': 'b',
+    'int': 2,
+    'bool': True,
+    'float': 2.0
+  }]
+
+
+def test_filter_less(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{UUID_COLUMN: '3', 'str': 'b', 'int': 2, 'bool': True, 'float': 1.0}]
+
+
+def test_filter_less_equal(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS_EQUAL, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [{
+    UUID_COLUMN: '2',
+    'str': 'b',
+    'int': 2,
+    'bool': True,
+    'float': 2.0
+  }, {
+    UUID_COLUMN: '3',
+    'str': 'b',
+    'int': 2,
+    'bool': True,
+    'float': 1.0
+  }]
+
+
+def test_filter_not_equal(make_test_data: TestDataMaker) -> None:
+  dataset = make_test_data(TEST_DATA)
+
+  id_filter: BinaryFilterTuple = ('float', BinaryOp.NOT_EQUAL, 2.0)
+  result = dataset.select_rows(filters=[id_filter])
+
+  assert list(result) == [
+    {
+      UUID_COLUMN: '1',
+      'str': 'a',
+      'int': 1,
+      'bool': False,
+      'float': 3.0
+    },
+    {
+      UUID_COLUMN: '3',
+      'str': 'b',
+      'int': 2,
+      'bool': True,
+      'float': 1.0
+    },
+    # NaNs are not counted when we are filtering a field.
+  ]
+
+
 def test_filter_by_list_of_ids(make_test_data: TestDataMaker) -> None:
   dataset = make_test_data(TEST_DATA)
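The trailing comment in `test_filter_not_equal` is worth unpacking: under IEEE-754, ordered comparisons with NaN are always false, but `!=` is true, which is presumably why the generated SQL adds an explicit `NOT isnan(...)` guard rather than relying on comparison semantics alone. A quick check in plain Python:

nan = float('nan')
print(nan > 2.0)   # False
print(nan <= 2.0)  # False
print(nan != 2.0)  # True -> without the isnan guard, NOT_EQUAL would match the NaN row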
src/data/dataset_select_rows_search_test.py
CHANGED
@@ -288,17 +288,9 @@ def test_concept_search(make_test_data: TestDataMaker, mocker: MockerFixture) ->
     },
   ]

-  # Make sure fit was called with negative examples.
   (_, embeddings, labels, _) = concept_model_mock.call_args_list[-1].args
-  assert embeddings.shape == (
+  assert embeddings.shape == (2, 3)
   assert labels == [
-    # Negative implicit labels.
-    False,
-    False,
-    False,
-    False,
-    False,
-    False,
     # Explicit labels.
     False,
     True
src/data/dataset_stats_test.py
CHANGED
@@ -15,7 +15,7 @@ SIMPLE_ITEMS: list[Item] = [{
   'str': 'a',
   'int': 1,
   'bool': False,
-  'float': 3.0
+  'float': 3.0,
 }, {
   UUID_COLUMN: '2',
   'str': 'b',
@@ -28,6 +28,9 @@ SIMPLE_ITEMS: list[Item] = [{
   'int': 2,
   'bool': True,
   'float': 1.0
+}, {
+  UUID_COLUMN: '4',
+  'float': float('nan')
 }]


@@ -40,7 +43,7 @@ def test_simple_stats(make_test_data: TestDataMaker) -> None:

   result = dataset.stats(leaf_path='float')
   assert result == StatsResult(
-    path=('float',), total_count=
+    path=('float',), total_count=4, approx_count_distinct=4, min_val=1.0, max_val=3.0)

   result = dataset.stats(leaf_path='bool')
   assert result == StatsResult(path=('bool',), total_count=3, approx_count_distinct=2)
src/data/dataset_utils.py
CHANGED
@@ -1,5 +1,6 @@
 """Utilities for working with datasets."""

+import json
 import math
 import os
 import pickle
@@ -283,7 +284,10 @@ def write_items_to_parquet(items: Iterable[Item], output_dir: str, schema: Schem
     if UUID_COLUMN not in item:
       item[UUID_COLUMN] = secrets.token_urlsafe(nbytes=12)  # 16 base64 characters.
     if os.getenv('DEBUG'):
+      try:
+        _validate(item, arrow_schema)
+      except Exception as e:
+        raise ValueError(f'Error validating item: {json.dumps(item)}') from e
     writer.write(item)
     num_items += 1
   writer.close()
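The DEBUG-only change above wraps validation failures so the offending record travels with the error, while exception chaining (`raise ... from e`) preserves the original cause. A minimal sketch of the same pattern; `validate` and its rule are stand-ins, only the try/except shape comes from the diff:

import json


def validate(item: dict) -> None:
  if 'uuid' not in item:  # hypothetical validation rule
    raise KeyError('uuid')


def write_item(item: dict) -> None:
  try:
    validate(item)
  except Exception as e:
    # Re-raise with the serialized item for context; the original error
    # remains reachable via __cause__.
    raise ValueError(f'Error validating item: {json.dumps(item)}') from e


try:
  write_item({'text': 'hello'})
except ValueError as err:
  print(err)            # Error validating item: {"text": "hello"}
  print(err.__cause__)  # 'uuid'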
|