nsthorat committed on
Commit
e9a1c18
1 Parent(s): c754234
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .env +0 -6
  2. .gitattributes +0 -35
  3. .gitignore +1 -1
  4. Dockerfile +3 -0
  5. data/concept/lilac/profanity/concept.json +0 -0
  6. data/concept/lilac/profanity/sbert.pkl +0 -3
  7. data/concept/lilac/toxicity/cohere.pkl +0 -3
  8. data/concept/lilac/toxicity/concept.json +0 -0
  9. data/concept/lilac/toxicity/openai.pkl +0 -3
  10. data/concept/lilac/toxicity/sbert.pkl +0 -3
  11. data/concept/local/outerspace/cohere.pkl +0 -3
  12. data/concept/local/outerspace/concept.json +0 -188
  13. data/concept/local/outerspace/openai.pkl +0 -3
  14. data/concept/local/outerspace/sbert.pkl +0 -3
  15. data/datasets/local/spotify/data-00000-of-00001.parquet +0 -3
  16. data/datasets/local/spotify/manifest.json +0 -27
  17. data/datasets/local/spotify/text/.concepts/local/aliens/sbert-neg-100.pkl +0 -3
  18. data/datasets/local/spotify/text/.concepts/local/outer_space/sbert-neg-100.pkl +0 -3
  19. data/datasets/local/spotify/text/.concepts/local/outerspace/sbert-neg-100.pkl +0 -3
  20. data/datasets/local/spotify/text/.concepts/local/phone_addiction/sbert-neg-100.pkl +0 -3
  21. data/datasets/local/spotify/text/sbert/data-00000-of-00001.parquet +0 -3
  22. data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/data-00000-of-00001.parquet +0 -3
  23. data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/signal_manifest.json +0 -64
  24. data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.keys.pkl +0 -3
  25. data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.npy +0 -3
  26. data/datasets/local/spotify/text/sbert/signal_manifest.json +0 -37
  27. requirements.txt +1 -0
  28. src/concepts/concept.py +17 -8
  29. src/concepts/concept_test.py +0 -84
  30. src/concepts/db_concept_test.py +0 -606
  31. src/data/dataset_compute_signal_chain_test.py +0 -255
  32. src/data/dataset_compute_signal_test.py +0 -669
  33. src/data/dataset_duckdb.py +15 -13
  34. src/data/dataset_select_groups_test.py +0 -317
  35. src/data/dataset_select_rows_filter_test.py +0 -200
  36. src/data/dataset_select_rows_schema_test.py +0 -551
  37. src/data/dataset_select_rows_search_test.py +0 -393
  38. src/data/dataset_select_rows_sort_test.py +0 -904
  39. src/data/dataset_select_rows_udf_test.py +0 -404
  40. src/data/dataset_stats_test.py +0 -125
  41. src/data/dataset_test.py +0 -860
  42. src/data/dataset_utils.py +68 -34
  43. src/data/dataset_utils_test.py +0 -114
  44. src/data/sources/csv_source_test.py +0 -42
  45. src/data/sources/huggingface_source_test.py +0 -170
  46. src/data/sources/json_source_test.py +0 -74
  47. src/data/sources/pandas_source_test.py +0 -91
  48. src/data/sources/source_registry_test.py +0 -55
  49. src/data_loader_test.py +0 -74
  50. src/embeddings/embedding.py +18 -6
.env CHANGED
@@ -26,9 +26,3 @@ DUCKDB_USE_VIEWS=0
  # HF_USERNAME=
  # The default repo to deploy to for a staging demo. Can be overridden by a command line flag.
  # HF_STAGING_DEMO_REPO='HF_ORG/HF_REPO_NAME'
-
- # HuggingFace demos: HuggingFace machine that runs the demo.
-
- # To read private uploaded data from the server (running on HF spaces) for the demo.
- # Get a token from https://huggingface.co/settings/tokens
- # HF_ACCESS_TOKEN=
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,4 +1,4 @@
- **/__pycache__
+ __pycache__/
  **/*.pyc
  **/*.pyo
  **/*.pyd
Dockerfile CHANGED
@@ -22,6 +22,9 @@ COPY /web/blueprint/build ./web/blueprint/build
  # Copy python files.
  COPY /src ./src/
 
+ # Copy the data files. We use glob so docker copy won't fail if the directory doesn't exist.
+ COPY /dat[a] ./data/
+
  CMD [ \
  "gunicorn", "src.server:app", \
  "--bind", "0.0.0.0:5432", \
data/concept/lilac/profanity/concept.json DELETED
The diff for this file is too large to render.
 
data/concept/lilac/profanity/sbert.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:647280d255e1a1fabff691683926fbb49dfaffe2f8151cf9913ec98816eef473
- size 844427
data/concept/lilac/toxicity/cohere.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:670e81b8448ab0ee5161a42b523410b3af80c6ccce8003cae78edebb9d0981c4
- size 9720631
data/concept/lilac/toxicity/concept.json DELETED
The diff for this file is too large to render.
 
data/concept/lilac/toxicity/openai.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e15e8235c2152b1412a8e2dee3dcb94b23e95f1fde6fb60f01b876a832e46404
- size 3678199
data/concept/lilac/toxicity/sbert.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8ac8b304760c88242eb6c567e1af87fd87731a192308df8cf43b253e24d2b0ec
- size 959111
data/concept/local/outerspace/cohere.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:30afc472c4beb1aabb24d5b8e633df6039ec532fd704d8318755e083592221f3
- size 331736
data/concept/local/outerspace/concept.json DELETED
@@ -1,188 +0,0 @@
1
- {
2
- "namespace": "local",
3
- "concept_name": "outerspace",
4
- "type": "text",
5
- "data": {
6
- "da77c67f82524ce1a276593471fd530f": {
7
- "label": true,
8
- "text": "Fly me to the moon and let me play among the stars.",
9
- "id": "da77c67f82524ce1a276593471fd530f"
10
- },
11
- "f73feff4be50410ab1ac468752d0301a": {
12
- "label": true,
13
- "text": "Space may be the final frontier but it's made in a Hollywood basement.",
14
- "id": "f73feff4be50410ab1ac468752d0301a"
15
- },
16
- "0f0815ed04604209842d9e7b1e3538f8": {
17
- "label": true,
18
- "text": "We're just a speck of dust within the galaxy.",
19
- "id": "0f0815ed04604209842d9e7b1e3538f8"
20
- },
21
- "2e41f637061e4ecb8b0d4e35abab9b63": {
22
- "label": true,
23
- "text": "In the darkest night, the stars shine bright and guide me to the moonlight.",
24
- "id": "2e41f637061e4ecb8b0d4e35abab9b63"
25
- },
26
- "fb65845f4dc84da1b276de30967592e3": {
27
- "label": true,
28
- "text": "We'll be shooting star through time and space\r\n\r\n",
29
- "id": "fb65845f4dc84da1b276de30967592e3"
30
- },
31
- "075534e3095b421687039291439b5524": {
32
- "label": true,
33
- "text": "Dreaming of love while cruising at high altitude \r\nDreaming of making love with you the way we should \r\nCloser to heaven. We're thirty thousand feet, up in the sky \r\nHere among the stars, our spirits will fly \r\n \r\nLeave all your worries as we soar over the clouds \r\nJet lag that's making you appear far from the crowd \r\nWhile we're suspended, locked in each others, sweet embrace \r",
34
- "id": "075534e3095b421687039291439b5524"
35
- },
36
- "4bb656032d0d4f449bac8aa5f23c3e48": {
37
- "label": true,
38
- "text": " \r\nI don't know where I don't know why \r\nBut somehow back in time again \r\nI'm on the edge that you can see \r\n \r\nI'm not particular at night \r\nA single party calling me \r\nYou won't be tracking me by sight \r\n \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r\nAt the speed of light \r\n \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r",
39
- "id": "4bb656032d0d4f449bac8aa5f23c3e48"
40
- },
41
- "4a6dda9001ea487991a1264e6a6c021b": {
42
- "label": true,
43
- "text": "Load redeem me, am I pure? \r\nAs pure as pure as heaven \r\nSent you money sent you flowers \r\nCould worship you for hours \r\nIn whose hands are we anyway? \r\n \r\nGo waiting for the stars \r\nTo come showering down \r\nFrom Moscow to Mars \r\nUniverse falling down \r\n \r\nYou got to look real hard \r\nIs it in your heart? \r\nYeah it's in there somewhere \r\nThe power wrapped in your palm \r",
44
- "id": "4a6dda9001ea487991a1264e6a6c021b"
45
- },
46
- "9aacce9311d24cb182aee783ca313c58": {
47
- "label": true,
48
- "text": "Growth is our future resource. \r\n \r\nJoin the state of the universe, \r\nUnited state of peace. \r\nJoin the state of the universe, \r\nUnited state of peace. \r\n \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r\nStarpeace, I see you growing, \r",
49
- "id": "9aacce9311d24cb182aee783ca313c58"
50
- },
51
- "313b8f9ce9164791b04ead82e6adb40f": {
52
- "label": false,
53
- "text": " \r\nEven I could see a light if it wasn't for the nights \r\n(Even I could see a light I think that I could make it) \r\nGuess my future would look bright if it wasn't for the nights\r\n\r\n",
54
- "id": "313b8f9ce9164791b04ead82e6adb40f"
55
- },
56
- "b9c587b74f084ef4917e7a52cd5c5cbe": {
57
- "label": true,
58
- "text": "Yea I think I know \r\nI really hate it when it gets too slow \r\nI gotta try and keep myself amused \r\nI love the way my rocket purrs \r\nI like it best when I see blurs \r\nYou gotta love to watch me light my fuse \r\n \r\nNo more lookin' back to yesterday \r\nI got the thing to blow us both away \r\nAll I need is you to navigate \r\nSo come and ride my Rocket 88 \r\n \r",
59
- "id": "b9c587b74f084ef4917e7a52cd5c5cbe"
60
- },
61
- "6f844600cc024117a22287557130a17b": {
62
- "label": false,
63
- "text": "They came flying from far away, now I'm under their spell \r\nI love hearing the stories that they tell \r\nThey've seen places beyond my land and they've found new horizons \r\nThey speak strangely but I understand \r\n \r\nAnd I dream I'm an eagle \r\nAnd I dream I can spread my wings \r\nFlying high, high, I'm a bird in the sky \r\nI'm an eagle that rides on the breeze \r",
64
- "id": "6f844600cc024117a22287557130a17b"
65
- },
66
- "8cddcff34f894743872ecc02262c2375": {
67
- "label": true,
68
- "text": "Fire! I can see it burning so brightly \r\nFire! I can feel it calling out to me \r\nAnd as the sun goes down \r\nIt starts to paint a picture \r\n \r\nOf an ancient town \r\nSo far away, across the endless sea \r\nLead me to the light \r\nAnd take me to the edge of heaven \r\n \r\nI'm standing in the night \r\nLooking for the edge of heaven \r\nWe'll be touching the edge of heaven \r\nTime \r\n \r",
69
- "id": "8cddcff34f894743872ecc02262c2375"
70
- },
71
- "3d044718f379452ab3c1e4d00c99f8f3": {
72
- "label": false,
73
- "text": "Fire! I can see it burning so brightly \r\nFire! I can feel it calling out to me \r\nAnd as the sun goes down \r\nIt starts to paint a picture \r\n \r\nOf an ancient town \r\nSo far away, across the endless sea \r\nLead me to the light \r\nAnd take me to the edge of heaven \r\n \r\nI'm standing in the night \r\nLooking for the edge of heaven \r\nWe'll be touching the edge of heaven \r\nTime \r\n \r",
74
- "id": "3d044718f379452ab3c1e4d00c99f8f3"
75
- },
76
- "d233250a91d44f13aac58eb5fa43afe6": {
77
- "label": true,
78
- "text": "Star \r\nWe go waiting for the stars \r\nTo come showering down \r\nFrom Moscow to Mars \r\nUniverse falling down \r\n \r\nYou got to look real hard \r\nThere's a fiery star \r\nHidden out there somewhere \r\nNot the satellite of love \r\nBut a laser \r\nShooting out it's shiny tongue there \r\n \r\nGod is love, God is war \r\nTV-preacher tell me more \r\nLoad redeem me, am I pure? \r",
79
- "id": "d233250a91d44f13aac58eb5fa43afe6"
80
- },
81
- "a30c9a5c63a2456f8f53a9177a522d7a": {
82
- "label": false,
83
- "text": "Tell me do you want to be free \r\n \r\nWell your love falls down you know \r\nAnd your heart might fall to pieces \r\nAnd I saw your soul get lost along the way \r\n \r\nAll these songs now they used to make you shine \r\nThey are just lullabies for your nightmares \r\nAnd Ill sing them softly now \r\n \r\nLately I've felt the warmth \r\nOf the one who tore down my walls \r\nBut then I look at you \r",
84
- "id": "a30c9a5c63a2456f8f53a9177a522d7a"
85
- },
86
- "89ce6961ff064f719212e68058bb2013": {
87
- "label": false,
88
- "text": "I Left Them Niggas Needin'Path \r\nAnd Ya'll Probly Won't Live To See This Weekend, \r\nGotta Go, Gotta Go, FUckin Mash Out \r\nI Hit The Dro' A Lil More And Then I Pass Out \r\nCrashin' The H2, Bitches I Hate You \r\nNow you Keep Talkin Shit, I Kidnap And Ducktape You \r\nLet Them Faggots Rape You \r\nThen It's Back To Mississippi, If Ya Boys Want Revenge \r\nTell Them Bitches Come And Get Me \r",
89
- "id": "89ce6961ff064f719212e68058bb2013"
90
- },
91
- "6de1b38adc9b4f48ac15609dad02faa0": {
92
- "label": true,
93
- "text": "In heaven's eye \r\n \r\nYes, this is our star. \r\nYes, this is our star. \r\nOur star our star.\r\n\r\n",
94
- "id": "6de1b38adc9b4f48ac15609dad02faa0"
95
- },
96
- "52ccd98280b849f498d838b6230285a7": {
97
- "label": false,
98
- "text": "Tell Them Bitches Come And Get Me \r\n'cause I Was Born In This Bitch To Die \r\nI'm In Queens, In Your 'Lac, With Your Bitch, Gettin' High \r\n \r\nYoung Buck: \r\nGold Grills, Coupe' Devilles Sittin On 22's \r\nThe Dirty, Dirty Baby \r\nShow 'Em How The South Do \r\nWe Pop Pills, Shoot To Kill, You Know What We 'Bout \r\nAnd On Behalf Of G-Unit, Welcome To The South \r\n \r\nLil Flip: \r",
99
- "id": "52ccd98280b849f498d838b6230285a7"
100
- },
101
- "866a61ec0ab04a54ade2532b7825c858": {
102
- "label": false,
103
- "text": "I Swear On The Soul's Of Our Dead Cousin's \r\nI Ain't Fuckin, Man I'm Commin Ak 40's Bustin', \r\n7's And Mack 11's \r\nI Told 'Em All I Ain't No Hoe \r\nBut Niggas Don't Listen Till You Kick A Nigga, \r\nSmack Him With That Callico \r\nI'm Tryin To Stay In Gods Plan \r\nBut I Hadta Show These Faggots That Your Fuckin With A Man, Ya Bitch! \r\nI Left Them Niggas Needin'Path \r",
104
- "id": "866a61ec0ab04a54ade2532b7825c858"
105
- },
106
- "0a2dbf3ee6cd46ae9f71ecb65e02674e": {
107
- "label": true,
108
- "text": "And filling up the space \r\nMen and women boys and girls \r\nThere are so many people in the world \r\nThinkin' about the world \r\nAnd all the people in it \r\nAnd I'm staring at the stars \r\nAnd into the infinite \r\nIn a world without a world \r\nOn a planet that's \r\nDriftin' in a space \r\n \r\nSeconds into minutes and minutes \r\nInto hour and hours into days \r",
109
- "id": "0a2dbf3ee6cd46ae9f71ecb65e02674e"
110
- },
111
- "fff7748b4c384cb49ae18f96df719aa8": {
112
- "label": false,
113
- "text": "And the way things ought to be \r\n \r\nWhat kind of difference \r\nCan on person make? \r\nCut to the chase\r\n\r\n",
114
- "id": "fff7748b4c384cb49ae18f96df719aa8"
115
- },
116
- "54971cdd9be0444096cacd2637a50ce4": {
117
- "label": false,
118
- "text": "With bar lights and pretty girls \r\nBut most nights I stay straight and think about my mom \r\nOh god, I miss her so much \r\n \r\nAnd there are people on the street \r\nThey're coming up to me \r\nThey're telling me that they like what I do now \r\nAnd so I tried my best when I took the fall \r\nTo get right back up, back in your arms \r\nIf you're out here why do I miss you so much \r\n \r",
119
- "id": "54971cdd9be0444096cacd2637a50ce4"
120
- },
121
- "048e4f04661d4f71a48d48f216b30975": {
122
- "label": true,
123
- "text": " \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r\nAt the speed of light \r\n \r\nShadows and the stars \r\nWe will not return \r\nHumanity won't save us \r\nWe slip into the night \r\n \r\nI'll say a mass for you and wave \r\nShooting plasma from my grave \r\n \r\nEvent horizon lost in space \r\nRunning in a human race \r\n \r\nI don't know where I don't know why \r",
124
- "id": "048e4f04661d4f71a48d48f216b30975"
125
- },
126
- "f4ee9e97357c4f2fa0ed627a6983e4de": {
127
- "label": false,
128
- "text": "I am here to tell you we can never meet again \r\nSimple really, isn't it, a word or two and then \r\nA lifetime of not knowing where or how or why or when \r\nYou think of me or speak of me or wonder what befell \r\nThe someone you once loved so long ago so well \r\n \r\nNever wonder what I'll feel as living shuffles by \r\nYou don't have to ask me and I need not reply \r",
129
- "id": "f4ee9e97357c4f2fa0ed627a6983e4de"
130
- },
131
- "797514b7375f4ef8bfbd3320936b266a": {
132
- "label": false,
133
- "text": " \r\nThe last time that I saw him he was trying hard to get \r\nA woman's education but he's not a woman yet \r\nAnd the last time that I saw her she was living with some boy \r\nWho gives her soul an empty room and gives her body joy. \r\n \r\nSo the great affair is over but whoever would have guessed \r\nIt would leave us all so vacant and so deeply unimpressed \r",
134
- "id": "797514b7375f4ef8bfbd3320936b266a"
135
- },
136
- "56663fdf792a4820b7ae2e4344542cfa": {
137
- "label": true,
138
- "text": "Yeah we'll find our star \r\nBut maybe that's another world \r\n \r\nFar away from where we are \r\nYeah we'll find our star \r\nBut maybe that's another world\r\n\r\n",
139
- "id": "56663fdf792a4820b7ae2e4344542cfa"
140
- },
141
- "d522d97e7d44430e945e40720d54e98d": {
142
- "label": false,
143
- "text": "The silly people just like you and better too. \r\nHow can you keep turning when the overture is burning in the faces \r\nOf the people in the churches of the land. \r\n \r\nThat's all it seems, there is only one dream. \r\nThe day has come at last.\r\n\r\n",
144
- "id": "d522d97e7d44430e945e40720d54e98d"
145
- },
146
- "761a17d5909d4c7c9cd0cd1ac8c2db76": {
147
- "label": false,
148
- "text": "Ah the man she wanted all her life was hanging by a thread \r\n\"I never even knew how much I wanted you,\" she said. \r\nHis muscles they were numbered and his style was obsolete. \r\n\"O baby, I have come too late.\" She knelt beside his feet. \r\n\"I'll never see a face like yours in years of men to come \r\nI'll never see such arms again in wrestling or in love.\" \r",
149
- "id": "761a17d5909d4c7c9cd0cd1ac8c2db76"
150
- },
151
- "ffc68f626c7d41be8661babedf589778": {
152
- "label": true,
153
- "text": "let us make computations of the stars. \r\n \r\nOlder, wiser, sadder, blinder, watch us run: \r\nfaster, longer, harder, stronger, now it comes: \r\ncolour blisters, image splinters gravitate \r\ntowards the centre, in final splendour disintegrate, \r\nThe universe now beckons \r\nand Man, too, must take His place... \r\njust a few last fleeting seconds \r\nto wander in the waste, \r",
154
- "id": "ffc68f626c7d41be8661babedf589778"
155
- },
156
- "8e8ffd440c2f48ebb5ae04810be5d090": {
157
- "label": false,
158
- "text": "And boy you'll see \r\nIt's an illusion shining down in front of me, \r\n \r\nAnd then you'll say \r\nEven in time we shall control the day, \r\nWhen what you'll see \r\nDeep inside base controlling you and me. \r\n \r\nAnd one peculiar point I see, \r\nAs one of many ones of me. \r\nAs truth is gathered, I rearrange, \r\nInside out, outside in, inside out, outside in, \r\nPerpetual change. \r\n \r",
159
- "id": "8e8ffd440c2f48ebb5ae04810be5d090"
160
- },
161
- "c61414a653bb4a9482f341dbfbea4a47": {
162
- "label": false,
163
- "text": "While there's still time to choose \r\n \r\nEvery day of my life I discover \r\nSomeone murdering my sisters and brothers \r\nIn the name of some god or another \r\nWhat do you know \r\n \r\nFor the first precious few it's time to go \r\nWhat might have been we'll never know \r\nAll those bad ideas became the law \r\nOh yes, we've forgotten what we're looking for \r",
164
- "id": "c61414a653bb4a9482f341dbfbea4a47"
165
- },
166
- "3a325e1d3789416584ad836e2d32df05": {
167
- "label": true,
168
- "text": "Earth is the third planet from the Sun and the only place known in the universe where life has originated and found habitability. This is enabled by Earth being a water world, the only one in the Solar System sustaining liquid surface water. Almost all of Earth's water is contained in its global ocean, spanning 70.8% of Earth's surface. The other 29.2% are spanned by land, consisting of continents",
169
- "id": "3a325e1d3789416584ad836e2d32df05"
170
- },
171
- "44e9840483164b6b97e06f909e25b8dc": {
172
- "label": false,
173
- "text": "Human geography\nToggle Human geography subsection\nCultural and historical viewpoint\nSee also\nNotes\nReferences\nExternal links\nEarth",
174
- "id": "44e9840483164b6b97e06f909e25b8dc"
175
- },
176
- "bcf625326bc64c6ca6d37fb59bffa5ba": {
177
- "label": true,
178
- "text": "When the ebbing tide retreats along the rocky shoreline\nIt leaves a trail of tide pools in a short-lived galaxy\nEach microcosmic planet, a complete society\nA simple kind of mirror to reflect upon our own\nAll the busy little creatures chasing out their destinies\nLiving in their pools, they soon forget about the sea\nWheel within wheels in a spiral array\nA pattern so grand and complex",
179
- "id": "bcf625326bc64c6ca6d37fb59bffa5ba"
180
- },
181
- "7c2be4b17d8f49069f6179c5256acc5e": {
182
- "label": true,
183
- "text": "Beneath my dreams and wishes \nI long for thy caresses. \n \nA bridal bed awaits us both, \nAfter the landscape of death I cross. \nBefore my sorrows I must die, \nNightwish I send through the starlit sky. \n \n\"Passed away in silence \nThe flute from the realm unseen \nEmpties it's heart \nMaking love to me \nWith it's enchanting melody. \nLight of Orion, \nShadow of Andromeda, ",
184
- "id": "7c2be4b17d8f49069f6179c5256acc5e"
185
- }
186
- },
187
- "version": 34
188
- }
data/concept/local/outerspace/openai.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7ea2acd96a43d1c678273e7ec297b1758a3d09d1137f0325ac3058ca9a110112
- size 126895
data/concept/local/outerspace/sbert.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9916794dbe5526af5103019735188b637f9975a5326a21713380058034e13525
- size 34935
data/datasets/local/spotify/data-00000-of-00001.parquet DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:32224657332b09187a737c73ab634f9d14c9ba9a240bd105f1b9819cde2afcef
- size 37128682
data/datasets/local/spotify/manifest.json DELETED
@@ -1,27 +0,0 @@
- {
-   "files": [
-     "data-00000-of-00001.parquet"
-   ],
-   "data_schema": {
-     "fields": {
-       "artist": {
-         "dtype": "string"
-       },
-       "song": {
-         "dtype": "string"
-       },
-       "link": {
-         "dtype": "string"
-       },
-       "text": {
-         "dtype": "string"
-       },
-       "__line_number__": {
-         "dtype": "int64"
-       },
-       "__rowid__": {
-         "dtype": "string"
-       }
-     }
-   }
- }
data/datasets/local/spotify/text/.concepts/local/aliens/sbert-neg-100.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:93f390fafd0d0db4ae6ae80d30bfbf8eb0a80fa9332f77f30449d40a11df0936
- size 183363
data/datasets/local/spotify/text/.concepts/local/outer_space/sbert-neg-100.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:3fc9ac4c9b8b8588e48ebabbe34598edb4431985d20e018225b84546b96ce2ea
- size 166637
data/datasets/local/spotify/text/.concepts/local/outerspace/sbert-neg-100.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f3432ea5dcfbe7f7a17c94a4cc0c09e3317b8a690456fdf3af3efa0dcaa6f4fc
- size 188685
data/datasets/local/spotify/text/.concepts/local/phone_addiction/sbert-neg-100.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f795fb8b5d52650bd9aa5c871ff5d480e95413cd0afb65822a634c02f6674825
- size 163242
data/datasets/local/spotify/text/sbert/data-00000-of-00001.parquet DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9796beb630cc3503f3c2ac9db8f71e4c1604570836d78bbf364e801cd427c39e
- size 2709987
data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/data-00000-of-00001.parquet DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d1ba0fe68cc02849b0a20b7f72047c8e9cb8e5ef5b57b0cd642fa0b0be8a6e06
- size 3340135
data/datasets/local/spotify/text/sbert/embedding/local/outerspace/v34/signal_manifest.json DELETED
@@ -1,64 +0,0 @@
- {
-   "files": [
-     "data-00000-of-00001.parquet"
-   ],
-   "parquet_id": "local/outerspace/v34(text.sbert.*.embedding)",
-   "data_schema": {
-     "fields": {
-       "__rowid__": {
-         "dtype": "string"
-       },
-       "text": {
-         "fields": {
-           "sbert": {
-             "repeated_field": {
-               "fields": {
-                 "embedding": {
-                   "fields": {
-                     "local/outerspace/v34": {
-                       "dtype": "float32",
-                       "signal": {
-                         "signal_name": "concept_score",
-                         "embedding": "sbert",
-                         "namespace": "local",
-                         "concept_name": "outerspace",
-                         "draft": "main",
-                         "num_negative_examples": 100
-                       },
-                       "bins": [
-                         [
-                           "Not in concept",
-                           null,
-                           0.5
-                         ],
-                         [
-                           "In concept",
-                           0.5,
-                           null
-                         ]
-                       ]
-                     }
-                   }
-                 }
-               }
-             }
-           }
-         }
-       }
-     }
-   },
-   "signal": {
-     "signal_name": "concept_score",
-     "embedding": "sbert",
-     "namespace": "local",
-     "concept_name": "outerspace",
-     "draft": "main",
-     "num_negative_examples": 100
-   },
-   "enriched_path": [
-     "text",
-     "sbert",
-     "*",
-     "embedding"
-   ]
- }
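Note on the deleted manifest above: its "bins" field encodes named score ranges as [label, low, high], where null marks an open bound. A minimal sketch of how such bins could be read in Python; the helper name and threshold usage below are illustrative only, not code from this repo:

from typing import Optional

# Bins as in the manifest above: (label, low, high); None means the bound is open.
BINS: list[tuple[str, Optional[float], Optional[float]]] = [
  ('Not in concept', None, 0.5),
  ('In concept', 0.5, None),
]

def bin_label(score: float) -> str:
  """Hypothetical helper: map a concept score to its bin label."""
  for label, low, high in BINS:
    if (low is None or score >= low) and (high is None or score < high):
      return label
  raise ValueError(f'score {score} matches no bin')

assert bin_label(0.83) == 'In concept'
assert bin_label(0.12) == 'Not in concept'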
data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.keys.pkl DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d5df43291782b8c731d4ce56537946654c642a01dc9a4e37de394836362f6b45
- size 3727400
data/datasets/local/spotify/text/sbert/embeddings-00000-of-00001.npy DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:94e10c23d7229541e1f60b791a659d13673b10a03649abf0ae092e0e18c5aee3
- size 170446976
data/datasets/local/spotify/text/sbert/signal_manifest.json DELETED
@@ -1,37 +0,0 @@
- {
-   "files": [
-     "data-00000-of-00001.parquet"
-   ],
-   "parquet_id": "sbert(text)",
-   "data_schema": {
-     "fields": {
-       "__rowid__": {
-         "dtype": "string"
-       },
-       "text": {
-         "fields": {
-           "sbert": {
-             "repeated_field": {
-               "fields": {
-                 "embedding": {
-                   "dtype": "embedding"
-                 }
-               },
-               "dtype": "string_span"
-             },
-             "signal": {
-               "signal_name": "sbert"
-             }
-           }
-         }
-       }
-     }
-   },
-   "signal": {
-     "signal_name": "sbert"
-   },
-   "enriched_path": [
-     "text"
-   ],
-   "embedding_filename_prefix": "embeddings-00000-of-00001"
- }
requirements.txt CHANGED
@@ -18,6 +18,7 @@ cytoolz==0.12.1 ; python_version >= "3.9" and python_version < "3.10"
  dask==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
  datasets==2.13.1 ; python_version >= "3.9" and python_version < "3.10"
  decorator==5.1.1 ; python_version >= "3.9" and python_version < "3.10"
+ detect-secrets==1.4.0 ; python_version >= "3.9" and python_version < "3.10"
  dill==0.3.6 ; python_version >= "3.9" and python_version < "3.10"
  distributed==2023.6.1 ; python_version >= "3.9" and python_version < "3.10"
  duckdb==0.8.1 ; python_version >= "3.9" and python_version < "3.10"
src/concepts/concept.py CHANGED
@@ -162,7 +162,7 @@ class LogisticEmbeddingModel:
   def __post_init__(self) -> None:
     # See `notebooks/Toxicity.ipynb` for an example of training a concept model.
     self._model = LogisticRegression(
-      class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=1_000, n_jobs=-1)
+      class_weight=None, C=30, tol=1e-5, warm_start=True, max_iter=5_000, n_jobs=-1)
 
   def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
     """Get the scores for the provided embeddings."""
@@ -175,11 +175,12 @@ class LogisticEmbeddingModel:
     return np.random.rand(len(embeddings))
 
   def _setup_training(
-      self, X_train: np.ndarray, y_train: list[bool],
+      self, X_train: np.ndarray, labels: list[bool],
       implicit_negatives: Optional[np.ndarray]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-    num_pos_labels = len([y for y in y_train if y])
-    num_neg_labels = len([y for y in y_train if not y])
-    sample_weights = [(1.0 / num_pos_labels if y else 1.0 / num_neg_labels) for y in y_train]
+    num_pos_labels = len([y for y in labels if y])
+    num_neg_labels = len([y for y in labels if not y])
+    sample_weights = [(1.0 / num_pos_labels if y else 1.0 / num_neg_labels) for y in labels]
+    y_train = np.array(labels)
 
     if implicit_negatives is not None:
       num_implicit_labels = len(implicit_negatives)
@@ -191,7 +192,14 @@ class LogisticEmbeddingModel:
     # Normalize sample weights to sum to the number of training examples.
     weights = np.array(sample_weights)
     weights *= (X_train.shape[0] / np.sum(weights))
-    return X_train, np.array(y_train), weights
+
+    # Shuffle the data in unison.
+    p = np.random.permutation(len(X_train))
+    X_train = X_train[p]
+    y_train = y_train[p]
+    weights = weights[p]
+
+    return X_train, y_train, weights
 
   def fit(self, embeddings: np.ndarray, labels: list[bool],
           implicit_negatives: Optional[np.ndarray]) -> None:
@@ -337,11 +345,12 @@ class ConceptModel:
 
     embedding_items = list(embedding.compute(examples))
     result_items: list[Item] = []
+    logistic_model = self._get_logistic_model(draft)
     for item in embedding_items:
       if not isinstance(item, list):
         raise ValueError('Item from embedding is not a list.')
-      embeddings = np.array([np.squeeze(res[EMBEDDING_KEY]) for res in item])
-      scores = self._get_logistic_model(draft).score_embeddings(embeddings).tolist()
+      embeddings = np.array([np.reshape(res[EMBEDDING_KEY], -1) for res in item])
+      scores = logistic_model.score_embeddings(embeddings).tolist()
 
       item_result: list[Item] = []
       for embedding_item, score in zip(item, scores):
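Note on the _setup_training change above: the new code balances positive and negative examples with per-class sample weights, normalizes the weights to sum to the number of examples, and then shuffles features, labels and weights with one shared permutation so the three arrays stay aligned. A minimal standalone sketch of that pattern; the function name and toy data are illustrative only, not code from this commit:

import numpy as np

def balanced_shuffled_training_data(
    X: np.ndarray, labels: list[bool]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Sketch: per-class weight balancing followed by an in-unison shuffle."""
  y = np.array(labels)
  num_pos = int(y.sum())
  num_neg = len(y) - num_pos
  # Each class contributes equal total weight, regardless of class imbalance.
  weights = np.where(y, 1.0 / num_pos, 1.0 / num_neg)
  # Rescale so the weights sum to the number of training examples.
  weights *= len(y) / weights.sum()
  # A single permutation keeps rows, labels and weights aligned while shuffling.
  p = np.random.permutation(len(y))
  return X[p], y[p], weights[p]

# Toy usage with made-up embeddings:
X = np.random.rand(4, 3)
X_s, y_s, w_s = balanced_shuffled_training_data(X, [True, False, False, True])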
src/concepts/concept_test.py DELETED
@@ -1,84 +0,0 @@
- """Tests for concept."""
-
- from ..schema import SignalInputType
- from .concept import DRAFT_MAIN, Concept, Example, draft_examples
-
-
- def test_draft_examples_main() -> None:
-   concept = Concept(
-     namespace='test_namespace',
-     concept_name='test_name',
-     type=SignalInputType.TEXT,
-     data={
-       '0': Example(id='0', label=True, text='hello'),
-       '1': Example(id='1', label=False, text='world'),
-     },
-     version=0)
-
-   assert draft_examples(concept, DRAFT_MAIN) == {
-     '0': Example(id='0', label=True, text='hello'),
-     '1': Example(id='1', label=False, text='world'),
-   }
-
-
- def test_draft_examples_simple_draft() -> None:
-   concept = Concept(
-     namespace='test_namespace',
-     concept_name='test_name',
-     type=SignalInputType.TEXT,
-     data={
-       '0': Example(id='0', label=True, text='hello'),
-       '1': Example(id='1', label=False, text='world'),
-       '2': Example(id='2', label=True, text='hello draft 1', draft='draft1'),
-       '3': Example(id='3', label=False, text='world draft 1', draft='draft1'),
-       '4': Example(id='4', label=True, text='hello draft 2', draft='draft2'),
-       '5': Example(id='5', label=False, text='world draft 2', draft='draft2'),
-     },
-     version=0)
-
-   assert draft_examples(concept, DRAFT_MAIN) == {
-     '0': Example(id='0', label=True, text='hello'),
-     '1': Example(id='1', label=False, text='world'),
-   }
-
-   assert draft_examples(concept, 'draft1') == {
-     '0': Example(id='0', label=True, text='hello'),
-     '1': Example(id='1', label=False, text='world'),
-     '2': Example(id='2', label=True, text='hello draft 1', draft='draft1'),
-     '3': Example(id='3', label=False, text='world draft 1', draft='draft1'),
-   }
-
-   assert draft_examples(concept, 'draft2') == {
-     '0': Example(id='0', label=True, text='hello'),
-     '1': Example(id='1', label=False, text='world'),
-     '4': Example(id='4', label=True, text='hello draft 2', draft='draft2'),
-     '5': Example(id='5', label=False, text='world draft 2', draft='draft2'),
-   }
-
-
- def test_draft_examples_draft_dedupe() -> None:
-   concept = Concept(
-     namespace='test_namespace',
-     concept_name='test_name',
-     type=SignalInputType.TEXT,
-     data={
-       '0': Example(id='0', label=True, text='hello'),
-       '1': Example(id='1', label=False, text='world'),
-       # Duplicate text.
-       '2': Example(id='2', label=False, text='hello', draft='draft'),
-       '3': Example(id='3', label=False, text='world draft', draft='draft'),
-     },
-     version=0)
-
-   assert draft_examples(concept, DRAFT_MAIN) == {
-     '0': Example(id='0', label=True, text='hello'),
-     '1': Example(id='1', label=False, text='world'),
-   }
-
-   assert draft_examples(concept, 'draft') == {
-     # 0 is deduplicated with 2.
-     '1': Example(id='1', label=False, text='world'),
-     # 2 overrides 0's label.
-     '2': Example(id='2', label=False, text='hello', draft='draft'),
-     '3': Example(id='3', label=False, text='world draft', draft='draft'),
-   }
src/concepts/db_concept_test.py DELETED
@@ -1,606 +0,0 @@
1
- """Tests for the the database concept."""
2
-
3
- from pathlib import Path
4
- from typing import Generator, Iterable, Optional, Type, cast
5
-
6
- import numpy as np
7
- import pytest
8
- from pytest_mock import MockerFixture
9
- from typing_extensions import override
10
-
11
- from ..config import CONFIG
12
- from ..data.dataset_duckdb import DatasetDuckDB
13
- from ..data.dataset_utils import lilac_embedding
14
- from ..db_manager import set_default_dataset_cls
15
- from ..schema import Item, RichData, SignalInputType
16
- from ..signals.signal import TextEmbeddingSignal, clear_signal_registry, register_signal
17
- from .concept import (
18
- DRAFT_MAIN,
19
- Concept,
20
- ConceptModel,
21
- DraftId,
22
- Example,
23
- ExampleIn,
24
- LogisticEmbeddingModel,
25
- )
26
- from .db_concept import (
27
- ConceptDB,
28
- ConceptInfo,
29
- ConceptModelDB,
30
- ConceptUpdate,
31
- DiskConceptDB,
32
- DiskConceptModelDB,
33
- )
34
-
35
- ALL_CONCEPT_DBS = [DiskConceptDB]
36
- ALL_CONCEPT_MODEL_DBS = [DiskConceptModelDB]
37
-
38
-
39
- @pytest.fixture(autouse=True)
40
- def set_data_path(tmp_path: Path, mocker: MockerFixture) -> None:
41
- mocker.patch.dict(CONFIG, {'LILAC_DATA_PATH': str(tmp_path)})
42
-
43
-
44
- EMBEDDING_MAP: dict[str, list[float]] = {
45
- 'not in concept': [1.0, 0.0, 0.0],
46
- 'in concept': [0.9, 0.1, 0.0],
47
- 'a new data point': [0.1, 0.2, 0.3],
48
- 'a true draft point': [0.4, 0.5, 0.6],
49
- 'a false draft point': [0.7, 0.8, 0.9],
50
- }
51
-
52
-
53
- class TestEmbedding(TextEmbeddingSignal):
54
- """A test embed function."""
55
- name = 'test_embedding'
56
-
57
- @override
58
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
59
- """Embed the examples, use a hashmap to the vector for simplicity."""
60
- for example in data:
61
- if example not in EMBEDDING_MAP:
62
- raise ValueError(f'Example "{str(example)}" not in embedding map')
63
- yield [lilac_embedding(0, len(example), np.array(EMBEDDING_MAP[cast(str, example)]))]
64
-
65
-
66
- @pytest.fixture(scope='module', autouse=True)
67
- def setup_teardown() -> Generator:
68
- set_default_dataset_cls(DatasetDuckDB)
69
- register_signal(TestEmbedding)
70
-
71
- # Unit test runs.
72
- yield
73
-
74
- # Teardown.
75
- clear_signal_registry()
76
-
77
-
78
- @pytest.mark.parametrize('db_cls', ALL_CONCEPT_DBS)
79
- class ConceptDBSuite:
80
-
81
- def test_create_concept(self, db_cls: Type[ConceptDB]) -> None:
82
- db = db_cls()
83
- db.create(namespace='test', name='test_concept', type=SignalInputType.TEXT)
84
-
85
- assert db.list() == [
86
- ConceptInfo(
87
- namespace='test', name='test_concept', type=SignalInputType.TEXT, drafts=[DRAFT_MAIN])
88
- ]
89
-
90
- # Make sure list with drafts relects the drafts.
91
- train_data = [
92
- ExampleIn(label=False, text='not in concept', draft='test_draft'),
93
- ExampleIn(label=True, text='in concept', draft='test_draft')
94
- ]
95
- db.edit('test', 'test_concept', ConceptUpdate(insert=train_data))
96
-
97
- assert db.list() == [
98
- ConceptInfo(
99
- namespace='test',
100
- name='test_concept',
101
- type=SignalInputType.TEXT,
102
- drafts=[DRAFT_MAIN, 'test_draft'])
103
- ]
104
-
105
- def test_add_example(self, db_cls: Type[ConceptDB]) -> None:
106
- db = db_cls()
107
- namespace = 'test'
108
- concept_name = 'test_concept'
109
- train_data = [
110
- ExampleIn(label=False, text='not in concept'),
111
- ExampleIn(label=True, text='in concept')
112
- ]
113
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
114
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
115
-
116
- concept = db.get(namespace, concept_name)
117
-
118
- assert concept is not None
119
-
120
- keys = list(concept.data.keys())
121
- assert concept == Concept(
122
- namespace=namespace,
123
- concept_name=concept_name,
124
- type=SignalInputType.TEXT,
125
- data={
126
- keys[0]: Example(id=keys[0], label=False, text='not in concept'),
127
- keys[1]: Example(id=keys[1], label=True, text='in concept')
128
- },
129
- version=1)
130
-
131
- # Add a draft labels.
132
- db.edit(
133
- namespace, concept_name,
134
- ConceptUpdate(insert=[
135
- ExampleIn(label=False, text='really not in concept', draft='test_draft'),
136
- ExampleIn(label=True, text='really in concept', draft='test_draft')
137
- ]))
138
-
139
- concept = db.get(namespace, concept_name)
140
- assert concept is not None
141
-
142
- keys = list(concept.data.keys())
143
- assert concept == Concept(
144
- namespace=namespace,
145
- concept_name=concept_name,
146
- type=SignalInputType.TEXT,
147
- data={
148
- keys[0]: Example(id=keys[0], label=False, text='not in concept'),
149
- keys[1]: Example(id=keys[1], label=True, text='in concept'),
150
- keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'),
151
- keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'),
152
- },
153
- version=2)
154
-
155
- def test_update_concept(self, db_cls: Type[ConceptDB]) -> None:
156
- db = db_cls()
157
- namespace = 'test'
158
- concept_name = 'test_concept'
159
- train_data = [
160
- ExampleIn(label=False, text='not in concept'),
161
- ExampleIn(label=True, text='in concept'),
162
- ExampleIn(label=False, text='really not in concept', draft='test_draft'),
163
- ExampleIn(label=True, text='really in concept', draft='test_draft')
164
- ]
165
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
166
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
167
-
168
- concept = db.get(namespace, concept_name)
169
- assert concept is not None
170
-
171
- keys = list(concept.data.keys())
172
- # Edit the first example.
173
- db.edit(
174
- namespace, concept_name,
175
- ConceptUpdate(update=[Example(id=keys[0], label=False, text='not in concept, updated')]))
176
- concept = db.get(namespace, concept_name)
177
-
178
- assert concept == Concept(
179
- namespace=namespace,
180
- concept_name=concept_name,
181
- type=SignalInputType.TEXT,
182
- data={
183
- # The first example should be updated alone.
184
- keys[0]: Example(id=keys[0], label=False, text='not in concept, updated'),
185
- keys[1]: Example(id=keys[1], label=True, text='in concept'),
186
- # Drafts are untouched.
187
- keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'),
188
- keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'),
189
- },
190
- version=2)
191
-
192
- # Edit the second example on the draft.
193
- db.edit(
194
- namespace, concept_name,
195
- ConceptUpdate(update=[
196
- Example(id=keys[3], label=True, text='really in concept, updated', draft='test_draft')
197
- ]))
198
- concept = db.get(namespace, concept_name)
199
-
200
- assert concept == Concept(
201
- namespace=namespace,
202
- concept_name=concept_name,
203
- type=SignalInputType.TEXT,
204
- data={
205
- # Main remains the same.
206
- keys[0]: Example(id=keys[0], label=False, text='not in concept, updated'),
207
- keys[1]: Example(id=keys[1], label=True, text='in concept'),
208
- keys[2]: Example(id=keys[2], label=False, text='really not in concept', draft='test_draft'),
209
- keys[3]: Example(
210
- id=keys[3], label=True, text='really in concept, updated', draft='test_draft'),
211
- },
212
- version=3)
213
-
214
- def test_remove_concept(self, db_cls: Type[ConceptDB]) -> None:
215
- db = db_cls()
216
- namespace = 'test'
217
- concept_name = 'test_concept'
218
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
219
-
220
- train_data = [
221
- ExampleIn(label=False, text='not in concept'),
222
- ExampleIn(label=True, text='in concept')
223
- ]
224
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
225
- concept = db.get(namespace, concept_name)
226
-
227
- db.remove(namespace, concept_name)
228
-
229
- concept = db.get(namespace, concept_name)
230
-
231
- assert concept is None
232
-
233
- def test_remove_concept_examples(self, db_cls: Type[ConceptDB]) -> None:
234
- db = db_cls()
235
- namespace = 'test'
236
- concept_name = 'test_concept'
237
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
238
-
239
- train_data = [
240
- ExampleIn(label=False, text='not in concept'),
241
- ExampleIn(label=True, text='in concept')
242
- ]
243
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
244
- concept = db.get(namespace, concept_name)
245
- assert concept is not None
246
-
247
- keys = list(concept.data.keys())
248
-
249
- db.edit(namespace, concept_name, ConceptUpdate(remove=[keys[0]]))
250
- concept = db.get(namespace, concept_name)
251
-
252
- assert concept == Concept(
253
- namespace=namespace,
254
- concept_name=concept_name,
255
- type=SignalInputType.TEXT,
256
- data={
257
- # key_0 was removed.
258
- keys[1]: Example(id=keys[1], label=True, text='in concept')
259
- },
260
- version=2)
261
-
262
- def test_remove_concept_examples_draft(self, db_cls: Type[ConceptDB]) -> None:
263
- db = db_cls()
264
- namespace = 'test'
265
- concept_name = 'test_concept'
266
- train_data = [
267
- ExampleIn(label=False, text='not in concept'),
268
- ExampleIn(label=True, text='in concept'),
269
- ExampleIn(label=False, text='really not in concept', draft='test_draft'),
270
- ExampleIn(label=True, text='really in concept', draft='test_draft')
271
- ]
272
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
273
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
274
- concept = db.get(namespace, concept_name)
275
- assert concept is not None
276
-
277
- keys = list(concept.data.keys())
278
-
279
- db.edit(namespace, concept_name, ConceptUpdate(remove=[keys[2]]))
280
- concept = db.get(namespace, concept_name)
281
-
282
- assert concept == Concept(
283
- namespace=namespace,
284
- concept_name=concept_name,
285
- type=SignalInputType.TEXT,
286
- data={
287
- keys[0]: Example(id=keys[0], label=False, text='not in concept'),
288
- keys[1]: Example(id=keys[1], label=True, text='in concept'),
289
- # The first draft example is removed.
290
- keys[3]: Example(id=keys[3], label=True, text='really in concept', draft='test_draft'),
291
- },
292
- version=2)
293
-
294
- def test_remove_invalid_id(self, db_cls: Type[ConceptDB]) -> None:
295
- db = db_cls()
296
- namespace = 'test'
297
- concept_name = 'test_concept'
298
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
299
-
300
- train_data = [
301
- ExampleIn(label=False, text='not in concept'),
302
- ExampleIn(label=True, text='in concept'),
303
- ExampleIn(label=False, text='really not in concept', draft='test_draft'),
304
- ExampleIn(label=True, text='really in concept', draft='test_draft')
305
- ]
306
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
307
-
308
- with pytest.raises(ValueError, match='Example with id "invalid_id" does not exist'):
309
- db.edit(namespace, concept_name, ConceptUpdate(remove=['invalid_id']))
310
-
311
- def test_edit_before_creation(self, db_cls: Type[ConceptDB]) -> None:
312
- db = db_cls()
313
- namespace = 'test'
314
- concept_name = 'test_concept'
315
-
316
- with pytest.raises(
317
- ValueError, match='Concept with namespace "test" and name "test_concept" does not exist'):
318
- db.edit(namespace, concept_name,
319
- ConceptUpdate(insert=[
320
- ExampleIn(label=False, text='not in concept'),
321
- ]))
322
-
323
- def test_edit_invalid_id(self, db_cls: Type[ConceptDB]) -> None:
324
- db = db_cls()
325
- namespace = 'test'
326
- concept_name = 'test_concept'
327
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
328
-
329
- train_data = [
330
- ExampleIn(label=False, text='not in concept'),
331
- ExampleIn(label=True, text='in concept')
332
- ]
333
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
334
-
335
- with pytest.raises(ValueError, match='Example with id "invalid_id" does not exist'):
336
- db.edit(namespace, concept_name,
337
- ConceptUpdate(update=[Example(id='invalid_id', label=False, text='not in concept')]))
338
-
339
- def test_merge_draft(self, db_cls: Type[ConceptDB]) -> None:
340
- db = db_cls()
341
- namespace = 'test'
342
- concept_name = 'test_concept'
343
- db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
344
-
345
- train_data = [
346
- ExampleIn(label=True, text='hello'),
347
- ExampleIn(label=False, text='world'),
348
- ExampleIn(label=True, text='hello draft 1', draft='draft1'),
349
- ExampleIn(label=False, text='world draft 1', draft='draft1'),
350
- # Duplicate of main.
351
- ExampleIn(label=False, text='hello', draft='draft2'),
352
- ExampleIn(label=True, text='world draft 2', draft='draft2'),
353
- ]
354
- db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
355
-
356
- db.merge_draft(namespace, concept_name, 'draft1')
357
-
358
- concept = db.get(namespace, concept_name)
359
- assert concept is not None
360
- keys = list(concept.data.keys())
361
-
362
- assert concept.dict() == Concept(
363
- namespace='test',
364
- concept_name='test_concept',
365
- type=SignalInputType.TEXT,
366
- data={
367
- keys[0]: Example(id=keys[0], label=True, text='hello'),
368
- keys[1]: Example(id=keys[1], label=False, text='world'),
369
- # Draft examples are merged.
370
- keys[2]: Example(id=keys[2], label=True, text='hello draft 1'),
371
- keys[3]: Example(id=keys[3], label=False, text='world draft 1'),
372
- # Draft 2 is untouched.
373
- keys[4]: Example(id=keys[4], label=False, text='hello', draft='draft2'),
374
- keys[5]: Example(id=keys[5], label=True, text='world draft 2', draft='draft2'),
375
- },
376
- version=2).dict()
377
-
378
- db.merge_draft(namespace, concept_name, 'draft2')
379
-
380
- concept = db.get(namespace, concept_name)
381
- assert concept is not None
382
-
383
- assert concept == Concept(
384
- namespace='test',
385
- concept_name='test_concept',
386
- type=SignalInputType.TEXT,
387
- data={
388
- # The first example is a duplicate of the label from the draft, so it is removed.
389
- keys[1]: Example(id=keys[1], label=False, text='world'),
390
- # Draft examples are merged.
391
- keys[2]: Example(id=keys[2], label=True, text='hello draft 1'),
392
- keys[3]: Example(id=keys[3], label=False, text='world draft 1'),
393
- # Draft examples are merged.
394
- keys[4]: Example(id=keys[4], label=False, text='hello'),
395
- keys[5]: Example(id=keys[5], label=True, text='world draft 2'),
396
- },
397
- version=3)
398
-
399
-
400
- def _make_test_concept_model(
401
- concept_db: ConceptDB,
402
- logistic_models: dict[DraftId, LogisticEmbeddingModel] = {}) -> ConceptModel:
403
- namespace = 'test'
404
- concept_name = 'test_concept'
405
- concept_db.create(namespace=namespace, name=concept_name, type=SignalInputType.TEXT)
406
-
407
- train_data = [
408
- ExampleIn(label=False, text='not in concept'),
409
- ExampleIn(label=True, text='in concept')
410
- ]
411
- concept_db.edit(namespace, concept_name, ConceptUpdate(insert=train_data))
412
- model = ConceptModel(
413
- namespace='test', concept_name='test_concept', embedding_name='test_embedding')
414
- model._logistic_models = logistic_models
415
- return model
416
-
417
-
418
- class TestLogisticModel(LogisticEmbeddingModel):
419
-
420
- @override
421
- def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
422
- """Get the scores for the provided embeddings."""
423
- return np.array([.1])
424
-
425
- @override
426
- def fit(self, embeddings: np.ndarray, labels: list[bool],
427
- implicit_negatives: Optional[np.ndarray]) -> None:
428
- pass
429
-
430
-
431
- @pytest.mark.parametrize('concept_db_cls', ALL_CONCEPT_DBS)
432
- @pytest.mark.parametrize('model_db_cls', ALL_CONCEPT_MODEL_DBS)
433
- class ConceptModelDBSuite:
434
-
435
- def test_save_and_get_model(self, concept_db_cls: Type[ConceptDB],
436
- model_db_cls: Type[ConceptModelDB]) -> None:
437
- concept_db = concept_db_cls()
438
- model_db = model_db_cls(concept_db)
439
- model = _make_test_concept_model(concept_db)
440
- model_db.sync(model)
441
- retrieved_model = model_db.get(
442
- namespace='test', concept_name='test_concept', embedding_name='test_embedding')
443
- if not retrieved_model:
444
- retrieved_model = model_db.create(
445
- namespace='test', concept_name='test_concept', embedding_name='test_embedding')
446
- assert retrieved_model.namespace == model.namespace
447
- assert retrieved_model.concept_name == model.concept_name
448
- assert retrieved_model.embedding_name == model.embedding_name
449
- assert retrieved_model.version == model.version
450
- assert retrieved_model.column_info == model.column_info
451
-
452
- def test_sync_model(self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB],
453
- mocker: MockerFixture) -> None:
454
-
455
- concept_db = concept_db_cls()
456
- model_db = model_db_cls(concept_db)
457
- logistic_model = TestLogisticModel()
458
- score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
459
- fit_mock = mocker.spy(TestLogisticModel, 'fit')
460
-
461
- model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
462
-
463
- assert model_db.in_sync(model) is False
464
- assert score_embeddings_mock.call_count == 0
465
- assert fit_mock.call_count == 0
466
-
467
- model_db.sync(model)
468
-
469
- assert model_db.in_sync(model) is True
470
- assert score_embeddings_mock.call_count == 0
471
- assert fit_mock.call_count == 1
472
-
473
- def test_out_of_sync_model(self, concept_db_cls: Type[ConceptDB],
474
- model_db_cls: Type[ConceptModelDB], mocker: MockerFixture) -> None:
475
- concept_db = concept_db_cls()
476
- model_db = model_db_cls(concept_db)
477
- score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
478
- fit_mock = mocker.spy(TestLogisticModel, 'fit')
479
- logistic_model = TestLogisticModel()
480
- model = _make_test_concept_model(concept_db, logistic_models={DRAFT_MAIN: logistic_model})
481
- model_db.sync(model)
482
- assert model_db.in_sync(model) is True
483
- assert score_embeddings_mock.call_count == 0
484
- assert fit_mock.call_count == 1
485
-
486
- (called_model, called_embeddings, called_labels,
487
- called_implicit_negatives) = fit_mock.call_args_list[-1].args
488
- assert called_model == logistic_model
489
- np.testing.assert_array_equal(
490
- called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
491
- assert called_labels == [False, True]
492
- assert called_implicit_negatives is None
493
-
494
- # Edit the concept.
495
- concept_db.edit('test', 'test_concept',
496
- ConceptUpdate(insert=[ExampleIn(label=False, text='a new data point')]))
497
-
498
- # Make sure the model is out of sync.
499
- assert model_db.in_sync(model) is False
500
- assert score_embeddings_mock.call_count == 0
501
- assert fit_mock.call_count == 1
502
-
503
- model_db.sync(model)
504
- assert model_db.in_sync(model) is True
505
- assert score_embeddings_mock.call_count == 0
506
- assert fit_mock.call_count == 2
507
- # Fit is called again with new points on main only.
508
- (called_model, called_embeddings, called_labels,
509
- called_implicit_negatives) = fit_mock.call_args_list[-1].args
510
- assert called_model == logistic_model
511
- np.testing.assert_array_equal(
512
- called_embeddings,
513
- np.array([
514
- EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept'],
515
- EMBEDDING_MAP['a new data point']
516
- ]))
517
- assert called_labels == [False, True, False]
518
- assert called_implicit_negatives is None
519
-
520
- def test_out_of_sync_draft_model(self, concept_db_cls: Type[ConceptDB],
521
- model_db_cls: Type[ConceptModelDB],
522
- mocker: MockerFixture) -> None:
523
- concept_db = concept_db_cls()
524
- model_db = model_db_cls(concept_db)
525
- score_embeddings_mock = mocker.spy(TestLogisticModel, 'score_embeddings')
526
- fit_mock = mocker.spy(TestLogisticModel, 'fit')
527
- main_model = TestLogisticModel()
528
- draft_model = TestLogisticModel()
529
- model = _make_test_concept_model(
530
- concept_db, logistic_models={
531
- DRAFT_MAIN: main_model,
532
- 'test_draft': draft_model
533
- })
534
- model_db.sync(model)
535
- assert model_db.in_sync(model) is True
536
- assert score_embeddings_mock.call_count == 0
537
- assert fit_mock.call_count == 1
538
-
539
- # Make sure drafts cause the model to be out of sync.
540
- concept_db.edit(
541
- 'test',
542
- 'test_concept',
543
- ConceptUpdate(insert=[
544
- ExampleIn(label=True, text='a true draft point', draft='test_draft'),
545
- ExampleIn(label=False, text='a false draft point', draft='test_draft'),
546
- # This point exists in main, but we switched the label.
547
- ExampleIn(label=False, text='in concept', draft='test_draft'),
548
- ]))
549
-
550
- # Make sure the model is out of sync.
551
- assert model_db.in_sync(model) is False
552
- assert score_embeddings_mock.call_count == 0
553
- assert fit_mock.call_count == 1
554
-
555
- model_db.sync(model)
556
- assert model_db.in_sync(model) is True
557
- assert score_embeddings_mock.call_count == 0
558
- assert fit_mock.call_count == 3 # Fit is called on both the draft, and main.
559
-
560
- # Fit is called again with the same points.
561
- ((called_model, called_embeddings, called_labels, called_implicit_negatives),
562
- (called_draft_model, called_draft_embeddings, called_draft_labels,
563
- called_draft_implicit_negatives)) = (
564
- c.args for c in fit_mock.call_args_list[-2:])
565
-
566
- # The draft model is called with the data from main, and the data from draft.
567
- assert called_draft_model == draft_model
568
- np.testing.assert_array_equal(
569
- called_draft_embeddings,
570
- np.array([
571
- EMBEDDING_MAP['a true draft point'], EMBEDDING_MAP['a false draft point'],
572
- EMBEDDING_MAP['in concept'], EMBEDDING_MAP['not in concept']
573
- ]))
574
- assert called_draft_labels == [
575
- True,
576
- False,
577
- # This was overriden by the draft.
578
- False,
579
- False
580
- ]
581
- assert called_draft_implicit_negatives is None
582
-
583
- # The main model was fit without the data from the draft.
584
- assert called_model == main_model
585
- np.testing.assert_array_equal(
586
- called_embeddings, np.array([EMBEDDING_MAP['not in concept'], EMBEDDING_MAP['in concept']]))
587
- assert called_labels == [False, True]
588
- assert called_implicit_negatives is None
589
-
590
- def test_embedding_not_found_in_map(self, concept_db_cls: Type[ConceptDB],
591
- model_db_cls: Type[ConceptModelDB]) -> None:
592
- concept_db = concept_db_cls()
593
- model_db = model_db_cls(concept_db)
594
- model = _make_test_concept_model(concept_db)
595
- model_db.sync(model)
596
-
597
- # Edit the concept.
598
- concept_db.edit('test', 'test_concept',
599
- ConceptUpdate(insert=[ExampleIn(label=False, text='unknown text')]))
600
-
601
- # Make sure the model is out of sync.
602
- assert model_db.in_sync(model) is False
603
-
604
- with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
605
- model_db.sync(model)
606
- model_db.sync(model)
src/data/dataset_compute_signal_chain_test.py DELETED
@@ -1,255 +0,0 @@
1
- """Tests for dataset.compute_signal() when signals are chained."""
2
-
3
- import re
4
- from typing import Iterable, List, Optional, cast
5
-
6
- import numpy as np
7
- import pytest
8
- from pytest_mock import MockerFixture
9
- from typing_extensions import override
10
-
11
- from ..embeddings.vector_store import VectorStore
12
- from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field, schema
13
- from ..signals.signal import (
14
- TextEmbeddingModelSignal,
15
- TextEmbeddingSignal,
16
- TextSignal,
17
- TextSplitterSignal,
18
- clear_signal_registry,
19
- register_signal,
20
- )
21
- from .dataset import DatasetManifest
22
- from .dataset_test_utils import (
23
- TEST_DATASET_NAME,
24
- TEST_NAMESPACE,
25
- TestDataMaker,
26
- enriched_embedding_span,
27
- enriched_embedding_span_field,
28
- enriched_item,
29
- )
30
- from .dataset_utils import lilac_embedding, lilac_span
31
-
32
- SIMPLE_ITEMS: list[Item] = [{
33
- UUID_COLUMN: '1',
34
- 'str': 'a',
35
- 'int': 1,
36
- 'bool': False,
37
- 'float': 3.0
38
- }, {
39
- UUID_COLUMN: '2',
40
- 'str': 'b',
41
- 'int': 2,
42
- 'bool': True,
43
- 'float': 2.0
44
- }, {
45
- UUID_COLUMN: '3',
46
- 'str': 'b',
47
- 'int': 2,
48
- 'bool': True,
49
- 'float': 1.0
50
- }]
51
-
52
- EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
53
- ('hello2.', [1.0, 1.0, 0.0]),
54
- ('hello world.', [1.0, 1.0, 1.0]),
55
- ('hello world2.', [2.0, 1.0, 1.0])]
56
-
57
- STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
58
-
59
-
60
- class TestSplitter(TextSplitterSignal):
61
- """Split documents into sentence by splitting on period."""
62
- name = 'test_splitter'
63
-
64
- @override
65
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
66
- for text in data:
67
- if not isinstance(text, str):
68
- raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
69
- sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
70
- yield [
71
- lilac_span(text.index(sentence),
72
- text.index(sentence) + len(sentence)) for sentence in sentences
73
- ]
74
-
75
-
76
- class TestEmbedding(TextEmbeddingSignal):
77
- """A test embed function."""
78
- name = 'test_embedding'
79
-
80
- @override
81
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
82
- """Call the embedding function."""
83
- for example in data:
84
- yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
85
-
86
-
87
- class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
88
- """Sums the embeddings to return a single floating point value."""
89
- name = 'test_embedding_sum'
90
-
91
- @override
92
- def fields(self) -> Field:
93
- return field('float32')
94
-
95
- @override
96
- def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
97
- # The signal just sums the values of the embedding.
98
- embedding_sums = vector_store.get(keys).sum(axis=1)
99
- for embedding_sum in embedding_sums.tolist():
100
- yield embedding_sum
101
-
102
-
103
- @pytest.fixture(scope='module', autouse=True)
104
- def setup_teardown() -> Iterable[None]:
105
- # Setup.
106
- register_signal(TestSplitter)
107
- register_signal(TestEmbedding)
108
- register_signal(TestEmbeddingSumSignal)
109
- register_signal(NamedEntity)
110
- # Unit test runs.
111
- yield
112
- # Teardown.
113
- clear_signal_registry()
114
-
115
-
116
- def test_manual_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
117
- dataset = make_test_data([{
118
- UUID_COLUMN: '1',
119
- 'text': 'hello.',
120
- }, {
121
- UUID_COLUMN: '2',
122
- 'text': 'hello2.',
123
- }])
124
-
125
- embed_mock = mocker.spy(TestEmbedding, 'compute')
126
-
127
- embedding_signal = TestEmbedding()
128
- dataset.compute_signal(embedding_signal, 'text')
129
- embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
130
- dataset.compute_signal(embedding_sum_signal, 'text')
131
-
132
- # Make sure the embedding signal is not called twice.
133
- assert embed_mock.call_count == 1
134
-
135
- assert dataset.manifest() == DatasetManifest(
136
- namespace=TEST_NAMESPACE,
137
- dataset_name=TEST_DATASET_NAME,
138
- data_schema=schema({
139
- UUID_COLUMN: 'string',
140
- 'text': field(
141
- 'string',
142
- fields={
143
- 'test_embedding': field(
144
- signal=embedding_signal.dict(),
145
- fields=[
146
- enriched_embedding_span_field(
147
- {'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
148
- ])
149
- }),
150
- }),
151
- num_items=2)
152
-
153
- result = dataset.select_rows()
154
- expected_result = [{
155
- UUID_COLUMN: '1',
156
- 'text': enriched_item(
157
- 'hello.', {'test_embedding': [enriched_embedding_span(0, 6, {'test_embedding_sum': 1.0})]})
158
- }, {
159
- UUID_COLUMN: '2',
160
- 'text': enriched_item(
161
- 'hello2.', {'test_embedding': [enriched_embedding_span(0, 7, {'test_embedding_sum': 2.0})]})
162
- }]
163
- assert list(result) == expected_result
164
-
165
-
166
- def test_auto_embedding_signal(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
167
- dataset = make_test_data([{
168
- UUID_COLUMN: '1',
169
- 'text': 'hello.',
170
- }, {
171
- UUID_COLUMN: '2',
172
- 'text': 'hello2.',
173
- }])
174
-
175
- embed_mock = mocker.spy(TestEmbedding, 'compute')
176
-
177
- # The embedding is automatically computed from the TestEmbeddingSumSignal.
178
- embedding_sum_signal = TestEmbeddingSumSignal(embedding=TestEmbedding.name)
179
- dataset.compute_signal(embedding_sum_signal, 'text')
180
-
181
- # Make sure the embedding signal is not called twice.
182
- assert embed_mock.call_count == 1
183
-
184
- assert dataset.manifest() == DatasetManifest(
185
- namespace=TEST_NAMESPACE,
186
- dataset_name=TEST_DATASET_NAME,
187
- data_schema=schema({
188
- UUID_COLUMN: 'string',
189
- 'text': field(
190
- 'string',
191
- fields={
192
- 'test_embedding': field(
193
- signal=embedding_sum_signal._embedding_signal.dict(),
194
- fields=[
195
- enriched_embedding_span_field(
196
- {'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
197
- ])
198
- }),
199
- }),
200
- num_items=2)
201
-
202
- result = dataset.select_rows()
203
- expected_result = [{
204
- UUID_COLUMN: '1',
205
- 'text': enriched_item(
206
- 'hello.', {'test_embedding': [enriched_embedding_span(0, 6, {'test_embedding_sum': 1.0})]})
207
- }, {
208
- UUID_COLUMN: '2',
209
- 'text': enriched_item(
210
- 'hello2.', {'test_embedding': [enriched_embedding_span(0, 7, {'test_embedding_sum': 2.0})]})
211
- }]
212
- assert list(result) == expected_result
213
-
214
-
215
- ENTITY_REGEX = r'[A-Za-z]+@[A-Za-z]+'
216
-
217
-
218
- class NamedEntity(TextSignal):
219
- """Find special entities."""
220
- name = 'entity'
221
-
222
- @override
223
- def fields(self) -> Field:
224
- return field(fields=['string_span'])
225
-
226
- @override
227
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[List[Item]]]:
228
- for text in data:
229
- if not isinstance(text, str):
230
- yield None
231
- continue
232
- yield [lilac_span(m.start(0), m.end(0)) for m in re.finditer(ENTITY_REGEX, text)]
233
-
234
-
235
- def test_entity_on_split_signal(make_test_data: TestDataMaker) -> None:
236
- text = 'Hello nik@test. Here are some other entities like pii@gmail and all@lilac.'
237
- dataset = make_test_data([{UUID_COLUMN: '1', 'text': text}])
238
- entity = NamedEntity()
239
- dataset.compute_signal(TestSplitter(), 'text')
240
- dataset.compute_signal(entity, ('text', 'test_splitter', '*'))
241
-
242
- result = dataset.select_rows(['text'])
243
- assert list(result) == [{
244
- UUID_COLUMN: '1',
245
- 'text': enriched_item(
246
- text, {
247
- 'test_splitter': [
248
- lilac_span(0, 15, {'entity': [lilac_span(6, 14)]}),
249
- lilac_span(16, 74, {'entity': [
250
- lilac_span(50, 59),
251
- lilac_span(64, 73),
252
- ]}),
253
- ]
254
- })
255
- }]
src/data/dataset_compute_signal_test.py DELETED
@@ -1,669 +0,0 @@
1
- """Tests for dataset.compute_signal()."""
2
-
3
- from typing import Iterable, Optional, Union, cast
4
-
5
- import numpy as np
6
- import pytest
7
- from typing_extensions import override
8
-
9
- from ..concepts.concept import ExampleIn
10
- from ..concepts.db_concept import ConceptUpdate, DiskConceptDB
11
- from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, SignalInputType, field, schema
12
- from ..signals.concept_scorer import ConceptScoreSignal
13
- from ..signals.signal import (
14
- TextEmbeddingSignal,
15
- TextSignal,
16
- TextSplitterSignal,
17
- clear_signal_registry,
18
- register_signal,
19
- )
20
- from .dataset import Column, DatasetManifest, GroupsSortBy, SortOrder, val
21
- from .dataset_test_utils import (
22
- TEST_DATASET_NAME,
23
- TEST_NAMESPACE,
24
- TestDataMaker,
25
- enriched_embedding_span_field,
26
- enriched_item,
27
- )
28
- from .dataset_utils import lilac_embedding, lilac_span
29
-
30
- SIMPLE_ITEMS: list[Item] = [{
31
- UUID_COLUMN: '1',
32
- 'str': 'a',
33
- 'int': 1,
34
- 'bool': False,
35
- 'float': 3.0
36
- }, {
37
- UUID_COLUMN: '2',
38
- 'str': 'b',
39
- 'int': 2,
40
- 'bool': True,
41
- 'float': 2.0
42
- }, {
43
- UUID_COLUMN: '3',
44
- 'str': 'b',
45
- 'int': 2,
46
- 'bool': True,
47
- 'float': 1.0
48
- }]
49
-
50
-
51
- class TestInvalidSignal(TextSignal):
52
- name = 'test_invalid_signal'
53
-
54
- @override
55
- def fields(self) -> Field:
56
- return field('int32')
57
-
58
- @override
59
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
60
- # Return an invalid output that doesn't match the input length.
61
- return []
62
-
63
-
64
- class TestSparseSignal(TextSignal):
65
- name = 'test_sparse_signal'
66
-
67
- @override
68
- def fields(self) -> Field:
69
- return field('int32')
70
-
71
- @override
72
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
73
- for text in data:
74
- if text == 'hello':
75
- # Skip this input.
76
- yield None
77
- else:
78
- yield len(text)
79
-
80
-
81
- class TestSparseRichSignal(TextSignal):
82
- """Find personally identifiable information (emails, phone numbers, etc)."""
83
- name = 'test_sparse_rich_signal'
84
-
85
- @override
86
- def fields(self) -> Field:
87
- return field(fields={'emails': ['string']})
88
-
89
- @override
90
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
91
- for text in data:
92
- if text == 'hello':
93
- # Skip this input.
94
- yield None
95
- else:
96
- yield {'emails': ['[email protected]', '[email protected]']}
97
-
98
-
99
- class TestParamSignal(TextSignal):
100
- name = 'param_signal'
101
- param: str
102
-
103
- def fields(self) -> Field:
104
- return field('string')
105
-
106
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
107
- for text_content in data:
108
- yield f'{str(text_content)}_{self.param}'
109
-
110
-
111
- class TestSignal(TextSignal):
112
- name = 'test_signal'
113
-
114
- @override
115
- def fields(self) -> Field:
116
- return field(fields={'len': 'int32', 'flen': 'float32'})
117
-
118
- @override
119
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
120
- return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
121
-
122
-
123
- class TestSplitSignal(TextSplitterSignal):
124
- """Split documents into sentence by splitting on period, generating entities."""
125
- name = 'test_split'
126
-
127
- @override
128
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
129
- for text in data:
130
- if not isinstance(text, str):
131
- raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
132
- sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
133
- yield [
134
- lilac_span(text.index(sentence),
135
- text.index(sentence) + len(sentence)) for sentence in sentences
136
- ]
137
-
138
-
139
- EMBEDDINGS: list[tuple[str, Union[list[float], list[list[float]]]]] = [
140
- ('hello.', [1.0, 0.0, 0.0]),
141
- # This embedding has an outer dimension of 1.
142
- ('hello2.', [[1.0, 1.0, 0.0]]),
143
- ('hello3.', [[0, 0, 1.]])
144
- ]
145
-
146
- STR_EMBEDDINGS: dict[str, Union[list[float], list[list[float]]]] = {
147
- text: embedding for text, embedding in EMBEDDINGS
148
- }
149
-
150
-
151
- class TestEmbedding(TextEmbeddingSignal):
152
- """A test embed function."""
153
- name = 'test_embedding'
154
-
155
- @override
156
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
157
- """Call the embedding function."""
158
- for example in data:
159
- example = cast(str, example)
160
- yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[example]))]
161
-
162
-
163
- class ComputedKeySignal(TextSignal):
164
- name = 'computed_key'
165
-
166
- @override
167
- def fields(self) -> Field:
168
- return field('int64')
169
-
170
- @override
171
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
172
- for text in data:
173
- yield 1
174
-
175
- def key(self, is_computed_signal: Optional[bool] = False) -> str:
176
- return f'key_{is_computed_signal}'
177
-
178
-
179
- @pytest.fixture(scope='module', autouse=True)
180
- def setup_teardown() -> Iterable[None]:
181
- # Setup.
182
- register_signal(TestSparseSignal)
183
- register_signal(TestSparseRichSignal)
184
- register_signal(TestParamSignal)
185
- register_signal(TestSignal)
186
- register_signal(TestSplitSignal)
187
- register_signal(TestEmbedding)
188
- register_signal(ComputedKeySignal)
189
- register_signal(ConceptScoreSignal)
190
-
191
- # Unit test runs.
192
- yield
193
- # Teardown.
194
- clear_signal_registry()
195
-
196
-
197
- def test_signal_output_validation(make_test_data: TestDataMaker) -> None:
198
- signal = TestInvalidSignal()
199
-
200
- dataset = make_test_data([{
201
- UUID_COLUMN: '1',
202
- 'text': 'hello',
203
- }, {
204
- UUID_COLUMN: '2',
205
- 'text': 'hello world',
206
- }])
207
-
208
- with pytest.raises(
209
- ValueError, match='The signal generated 0 values but the input data had 2 values.'):
210
- dataset.compute_signal(signal, 'text')
211
-
212
-
213
- def test_sparse_signal(make_test_data: TestDataMaker) -> None:
214
- dataset = make_test_data([{
215
- UUID_COLUMN: '1',
216
- 'text': 'hello',
217
- }, {
218
- UUID_COLUMN: '2',
219
- 'text': 'hello world',
220
- }])
221
-
222
- dataset.compute_signal(TestSparseSignal(), 'text')
223
-
224
- result = dataset.select_rows(['text'])
225
- assert list(result) == [{
226
- UUID_COLUMN: '1',
227
- 'text': enriched_item('hello', {'test_sparse_signal': None})
228
- }, {
229
- UUID_COLUMN: '2',
230
- 'text': enriched_item('hello world', {'test_sparse_signal': 11})
231
- }]
232
-
233
-
234
- def test_sparse_rich_signal(make_test_data: TestDataMaker) -> None:
235
- dataset = make_test_data([{
236
- UUID_COLUMN: '1',
237
- 'text': 'hello',
238
- }, {
239
- UUID_COLUMN: '2',
240
- 'text': 'hello world',
241
- }])
242
-
243
- dataset.compute_signal(TestSparseRichSignal(), 'text')
244
-
245
- result = dataset.select_rows(['text'])
246
- assert list(result) == [{
247
- UUID_COLUMN: '1',
248
- 'text': enriched_item('hello', {'test_sparse_rich_signal': None})
249
- }, {
250
- UUID_COLUMN: '2',
251
- 'text': enriched_item(
252
- 'hello world',
253
- {'test_sparse_rich_signal': {
254
- 'emails': ['[email protected]', '[email protected]']
255
- }})
256
- }]
257
-
258
-
259
- def test_source_joined_with_signal(make_test_data: TestDataMaker) -> None:
260
- dataset = make_test_data(SIMPLE_ITEMS)
261
- assert dataset.manifest() == DatasetManifest(
262
- namespace=TEST_NAMESPACE,
263
- dataset_name=TEST_DATASET_NAME,
264
- data_schema=schema({
265
- UUID_COLUMN: 'string',
266
- 'str': 'string',
267
- 'int': 'int32',
268
- 'bool': 'boolean',
269
- 'float': 'float32',
270
- }),
271
- num_items=3)
272
-
273
- test_signal = TestSignal()
274
- dataset.compute_signal(test_signal, 'str')
275
-
276
- # Check the enriched dataset manifest has 'text' enriched.
277
- assert dataset.manifest() == DatasetManifest(
278
- namespace=TEST_NAMESPACE,
279
- dataset_name=TEST_DATASET_NAME,
280
- data_schema=schema({
281
- UUID_COLUMN: 'string',
282
- 'str': field(
283
- 'string',
284
- fields={
285
- 'test_signal': field(
286
- signal=test_signal.dict(), fields={
287
- 'len': 'int32',
288
- 'flen': 'float32'
289
- }),
290
- }),
291
- 'int': 'int32',
292
- 'bool': 'boolean',
293
- 'float': 'float32',
294
- }),
295
- num_items=3)
296
-
297
- result = dataset.select_rows(['str'])
298
- assert list(result) == [{
299
- UUID_COLUMN: '1',
300
- 'str': enriched_item('a', {'test_signal': {
301
- 'len': 1,
302
- 'flen': 1.0
303
- }}),
304
- }, {
305
- UUID_COLUMN: '2',
306
- 'str': enriched_item('b', {'test_signal': {
307
- 'len': 1,
308
- 'flen': 1.0
309
- }}),
310
- }, {
311
- UUID_COLUMN: '3',
312
- 'str': enriched_item('b', {'test_signal': {
313
- 'len': 1,
314
- 'flen': 1.0
315
- }}),
316
- }]
317
-
318
- # Select a specific signal leaf test_signal.flen with val('str').
319
- result = dataset.select_rows([val('str'), ('str', 'test_signal', 'flen')])
320
-
321
- assert list(result) == [{
322
- UUID_COLUMN: '1',
323
- f'str.{VALUE_KEY}': 'a',
324
- 'str.test_signal.flen': 1.0
325
- }, {
326
- UUID_COLUMN: '2',
327
- f'str.{VALUE_KEY}': 'b',
328
- 'str.test_signal.flen': 1.0
329
- }, {
330
- UUID_COLUMN: '3',
331
- f'str.{VALUE_KEY}': 'b',
332
- 'str.test_signal.flen': 1.0
333
- }]
334
-
335
- # Select a specific signal leaf test_signal.flen and the whole 'str' subtree.
336
- result = dataset.select_rows(['str', ('str', 'test_signal', 'flen')])
337
-
338
- assert list(result) == [{
339
- UUID_COLUMN: '1',
340
- 'str': enriched_item('a', {'test_signal': {
341
- 'len': 1,
342
- 'flen': 1.0
343
- }}),
344
- 'str.test_signal.flen': 1.0
345
- }, {
346
- UUID_COLUMN: '2',
347
- 'str': enriched_item('b', {'test_signal': {
348
- 'len': 1,
349
- 'flen': 1.0
350
- }}),
351
- 'str.test_signal.flen': 1.0
352
- }, {
353
- UUID_COLUMN: '3',
354
- 'str': enriched_item('b', {'test_signal': {
355
- 'len': 1,
356
- 'flen': 1.0
357
- }}),
358
- 'str.test_signal.flen': 1.0
359
- }]
360
-
361
- # Select multiple signal leafs with aliasing.
362
- result = dataset.select_rows([
363
- val('str'),
364
- Column(('str', 'test_signal', 'flen'), alias='flen'),
365
- Column(('str', 'test_signal', 'len'), alias='len')
366
- ])
367
-
368
- assert list(result) == [{
369
- UUID_COLUMN: '1',
370
- f'str.{VALUE_KEY}': 'a',
371
- 'flen': 1.0,
372
- 'len': 1
373
- }, {
374
- UUID_COLUMN: '2',
375
- f'str.{VALUE_KEY}': 'b',
376
- 'flen': 1.0,
377
- 'len': 1
378
- }, {
379
- UUID_COLUMN: '3',
380
- f'str.{VALUE_KEY}': 'b',
381
- 'flen': 1.0,
382
- 'len': 1
383
- }]
384
-
385
-
386
- def test_parameterized_signal(make_test_data: TestDataMaker) -> None:
387
- dataset = make_test_data([{
388
- UUID_COLUMN: '1',
389
- 'text': 'hello'
390
- }, {
391
- UUID_COLUMN: '2',
392
- 'text': 'everybody'
393
- }])
394
- test_signal_a = TestParamSignal(param='a')
395
- test_signal_b = TestParamSignal(param='b')
396
- dataset.compute_signal(test_signal_a, 'text')
397
- dataset.compute_signal(test_signal_b, 'text')
398
-
399
- assert dataset.manifest() == DatasetManifest(
400
- namespace=TEST_NAMESPACE,
401
- dataset_name=TEST_DATASET_NAME,
402
- data_schema=schema({
403
- UUID_COLUMN: 'string',
404
- 'text': field(
405
- 'string',
406
- fields={
407
- 'param_signal(param=a)': field('string', test_signal_a.dict()),
408
- 'param_signal(param=b)': field('string', test_signal_b.dict()),
409
- }),
410
- }),
411
- num_items=2)
412
-
413
- result = dataset.select_rows(['text'])
414
- assert list(result) == [{
415
- UUID_COLUMN: '1',
416
- 'text': enriched_item('hello', {
417
- 'param_signal(param=a)': 'hello_a',
418
- 'param_signal(param=b)': 'hello_b',
419
- })
420
- }, {
421
- UUID_COLUMN: '2',
422
- 'text': enriched_item('everybody', {
423
- 'param_signal(param=a)': 'everybody_a',
424
- 'param_signal(param=b)': 'everybody_b',
425
- })
426
- }]
427
-
428
-
429
- def test_split_signal(make_test_data: TestDataMaker) -> None:
430
- dataset = make_test_data([{
431
- UUID_COLUMN: '1',
432
- 'text': '[1, 1] first sentence. [1, 1] second sentence.',
433
- }, {
434
- UUID_COLUMN: '2',
435
- 'text': 'b2 [2, 1] first sentence. [2, 1] second sentence.',
436
- }])
437
-
438
- signal = TestSplitSignal()
439
- dataset.compute_signal(signal, 'text')
440
-
441
- assert dataset.manifest() == DatasetManifest(
442
- namespace=TEST_NAMESPACE,
443
- dataset_name=TEST_DATASET_NAME,
444
- data_schema=schema({
445
- UUID_COLUMN: 'string',
446
- 'text': field(
447
- 'string', fields={'test_split': field(signal=signal.dict(), fields=[field('string_span')])})
448
- }),
449
- num_items=2)
450
-
451
- result = dataset.select_rows(['text'])
452
- expected_result = [{
453
- UUID_COLUMN: '1',
454
- 'text': enriched_item('[1, 1] first sentence. [1, 1] second sentence.',
455
- {'test_split': [lilac_span(0, 22), lilac_span(23, 46)]})
456
- }, {
457
- UUID_COLUMN: '2',
458
- 'text': enriched_item('b2 [2, 1] first sentence. [2, 1] second sentence.',
459
- {'test_split': [
460
- lilac_span(0, 25),
461
- lilac_span(26, 49),
462
- ]})
463
- }]
464
- assert list(result) == expected_result
465
-
466
-
467
- def test_signal_on_repeated_field(make_test_data: TestDataMaker) -> None:
468
- dataset = make_test_data([{
469
- UUID_COLUMN: '1',
470
- 'text': ['hello', 'everybody'],
471
- }, {
472
- UUID_COLUMN: '2',
473
- 'text': ['hello2', 'everybody2'],
474
- }])
475
- test_signal = TestSignal()
476
- # Run the signal on the repeated field.
477
- dataset.compute_signal(test_signal, ('text', '*'))
478
-
479
- # Check the enriched dataset manifest has 'text' enriched.
480
- assert dataset.manifest() == DatasetManifest(
481
- namespace=TEST_NAMESPACE,
482
- dataset_name=TEST_DATASET_NAME,
483
- data_schema=schema({
484
- UUID_COLUMN: 'string',
485
- 'text': field(fields=[
486
- field(
487
- 'string',
488
- fields={
489
- 'test_signal': field(
490
- signal=test_signal.dict(), fields={
491
- 'len': 'int32',
492
- 'flen': 'float32'
493
- })
494
- })
495
- ])
496
- }),
497
- num_items=2)
498
-
499
- result = dataset.select_rows([('text', '*')])
500
-
501
- assert list(result) == [{
502
- UUID_COLUMN: '1',
503
- 'text.*': [
504
- enriched_item('hello', {'test_signal': {
505
- 'len': 5,
506
- 'flen': 5.0
507
- }}),
508
- enriched_item('everybody', {'test_signal': {
509
- 'len': 9,
510
- 'flen': 9.0
511
- }})
512
- ]
513
- }, {
514
- UUID_COLUMN: '2',
515
- 'text.*': [
516
- enriched_item('hello2', {'test_signal': {
517
- 'len': 6,
518
- 'flen': 6.0
519
- }}),
520
- enriched_item('everybody2', {'test_signal': {
521
- 'len': 10,
522
- 'flen': 10.0
523
- }})
524
- ]
525
- }]
526
-
527
-
528
- def test_text_splitter(make_test_data: TestDataMaker) -> None:
529
- dataset = make_test_data([{
530
- UUID_COLUMN: '1',
531
- 'text': '[1, 1] first sentence. [1, 1] second sentence.',
532
- }, {
533
- UUID_COLUMN: '2',
534
- 'text': 'b2 [2, 1] first sentence. [2, 1] second sentence.',
535
- }])
536
-
537
- dataset.compute_signal(TestSplitSignal(), 'text')
538
-
539
- result = dataset.select_rows(['text'])
540
- expected_result = [{
541
- UUID_COLUMN: '1',
542
- 'text': enriched_item('[1, 1] first sentence. [1, 1] second sentence.',
543
- {'test_split': [
544
- lilac_span(0, 22),
545
- lilac_span(23, 46),
546
- ]}),
547
- }, {
548
- UUID_COLUMN: '2',
549
- 'text': enriched_item('b2 [2, 1] first sentence. [2, 1] second sentence.',
550
- {'test_split': [
551
- lilac_span(0, 25),
552
- lilac_span(26, 49),
553
- ]}),
554
- }]
555
- assert list(result) == expected_result
556
-
557
-
558
- def test_embedding_signal(make_test_data: TestDataMaker) -> None:
559
- dataset = make_test_data([{
560
- UUID_COLUMN: '1',
561
- 'text': 'hello.',
562
- }, {
563
- UUID_COLUMN: '2',
564
- 'text': 'hello2.',
565
- }])
566
-
567
- embedding_signal = TestEmbedding()
568
- dataset.compute_signal(embedding_signal, 'text')
569
-
570
- assert dataset.manifest() == DatasetManifest(
571
- namespace=TEST_NAMESPACE,
572
- dataset_name=TEST_DATASET_NAME,
573
- data_schema=schema({
574
- UUID_COLUMN: 'string',
575
- 'text': field(
576
- 'string',
577
- fields={
578
- 'test_embedding': field(
579
- signal=embedding_signal.dict(), fields=[enriched_embedding_span_field()])
580
- }),
581
- }),
582
- num_items=2)
583
-
584
- result = dataset.select_rows()
585
-
586
- # Embeddings are replaced with "None".
587
- expected_result = [{
588
- UUID_COLUMN: '1',
589
- 'text': enriched_item('hello.', {'test_embedding': [lilac_embedding(0, 6, None)]})
590
- }, {
591
- UUID_COLUMN: '2',
592
- 'text': enriched_item('hello2.', {'test_embedding': [lilac_embedding(0, 7, None)]})
593
- }]
594
- assert list(result) == expected_result
595
-
596
-
597
- def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None:
598
- dataset = make_test_data([{
599
- UUID_COLUMN: '1',
600
- 'text': 'hello.',
601
- }, {
602
- UUID_COLUMN: '2',
603
- 'text': 'hello2.',
604
- }])
605
-
606
- signal = ComputedKeySignal()
607
- dataset.compute_signal(signal, 'text')
608
-
609
- assert dataset.manifest() == DatasetManifest(
610
- namespace=TEST_NAMESPACE,
611
- dataset_name=TEST_DATASET_NAME,
612
- data_schema=schema({
613
- UUID_COLUMN: 'string',
614
- 'text': field('string', fields={'key_True': field('int64', signal=signal.dict())}),
615
- }),
616
- num_items=2)
617
-
618
- result = dataset.select_rows()
619
-
620
- # Embeddings are replaced with "None".
621
- expected_result = [{
622
- UUID_COLUMN: '1',
623
- 'text': enriched_item('hello.', {'key_True': 1})
624
- }, {
625
- UUID_COLUMN: '2',
626
- 'text': enriched_item('hello2.', {'key_True': 1})
627
- }]
628
- assert list(result) == expected_result
629
-
630
-
631
- def test_concept_signal_with_select_groups(make_test_data: TestDataMaker) -> None:
632
- dataset = make_test_data([{
633
- UUID_COLUMN: '1',
634
- 'text': 'hello.',
635
- }, {
636
- UUID_COLUMN: '2',
637
- 'text': 'hello2.',
638
- }, {
639
- UUID_COLUMN: '3',
640
- 'text': 'hello3.',
641
- }])
642
-
643
- embedding_signal = TestEmbedding()
644
- dataset.compute_signal(embedding_signal, 'text')
645
-
646
- concept_db = DiskConceptDB()
647
- concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT)
648
- concept_db.edit(
649
- 'test_namespace', 'test_concept',
650
- ConceptUpdate(insert=[
651
- ExampleIn(label=False, text='hello.'),
652
- ExampleIn(label=True, text='hello2.'),
653
- ExampleIn(label=False, text='hello3.')
654
- ]))
655
-
656
- concept_signal = ConceptScoreSignal(
657
- namespace='test_namespace', concept_name='test_concept', embedding='test_embedding')
658
-
659
- dataset.compute_signal(concept_signal, 'text')
660
-
661
- concept_key = concept_signal.key(is_computed_signal=True)
662
- result = dataset.select_groups(f'text.test_embedding.*.embedding.{concept_key}')
663
- assert result.counts == [('Not in concept', 2), ('In concept', 1)]
664
-
665
- result = dataset.select_groups(
666
- f'text.test_embedding.*.embedding.{concept_key}',
667
- sort_by=GroupsSortBy.COUNT,
668
- sort_order=SortOrder.ASC)
669
- assert result.counts == [('In concept', 1), ('Not in concept', 2)]
src/data/dataset_duckdb.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import re
7
  import shutil
8
  import threading
9
- from typing import Any, Iterable, Optional, Sequence, Type, Union, cast
10
 
11
  import duckdb
12
  import numpy as np
@@ -93,6 +93,7 @@ from .dataset_utils import (
93
  read_embedding_index,
94
  replace_embeddings_with_none,
95
  schema_contains_path,
 
96
  unflatten,
97
  wrap_in_dicts,
98
  write_item_embeddings_to_disk,
@@ -686,7 +687,7 @@ class DatasetDuckDB(Dataset):
686
  star_in_cols = any(col.path == ('*',) for col in cols)
687
  if not cols or star_in_cols:
688
  # Select all columns.
689
- cols.extend([Column(name) for name in schema.fields.keys()])
690
  if star_in_cols:
691
  cols = [col for col in cols if col.path != ('*',)]
692
  return cols
@@ -941,8 +942,9 @@ class DatasetDuckDB(Dataset):
941
  # The input is an embedding.
942
  embedding_signal = cast(TextEmbeddingModelSignal, signal)
943
  vector_store = self.get_vector_store(embedding_signal.embedding, udf_col.path)
944
- flat_keys = flatten_keys(df[UUID_COLUMN], input)
945
- signal_out = signal.vector_compute(flat_keys, vector_store)
 
946
  # Add progress.
947
  if task_step_id is not None:
948
  signal_out = progress(
@@ -953,8 +955,9 @@ class DatasetDuckDB(Dataset):
953
  df[signal_column] = unflatten(signal_out, input)
954
  else:
955
  num_rich_data = count_primitives(input)
956
- flat_input = cast(Iterable[RichData], flatten(input))
957
- signal_out = signal.compute(flat_input)
 
958
  # Add progress.
959
  if task_step_id is not None:
960
  signal_out = progress(
@@ -962,22 +965,21 @@ class DatasetDuckDB(Dataset):
962
  task_step_id=task_step_id,
963
  estimated_len=num_rich_data,
964
  step_description=f'Computing {signal.key()}...')
965
- signal_out = list(signal_out)
966
-
967
  if signal_column in temp_column_to_offset_column:
968
  offset_column_name, field = temp_column_to_offset_column[signal_column]
969
- nested_spans: Iterable[Item] = df[offset_column_name]
970
  flat_spans = list(flatten(nested_spans))
971
- for span, item in zip(flat_spans, signal_out):
972
  _offset_any_span(cast(int, span[VALUE_KEY][TEXT_SPAN_START_FEATURE]), item, field)
973
 
974
- if len(signal_out) != num_rich_data:
975
  raise ValueError(
976
- f'The signal generated {len(signal_out)} values but the input data had '
977
  f"{num_rich_data} values. This means the signal either didn't generate a "
978
  '"None" for a sparse output, or generated too many items.')
979
 
980
- df[signal_column] = unflatten(signal_out, input)
981
 
982
  signal.teardown()
983
 
 
6
  import re
7
  import shutil
8
  import threading
9
+ from typing import Any, Iterable, Iterator, Optional, Sequence, Type, Union, cast
10
 
11
  import duckdb
12
  import numpy as np
 
93
  read_embedding_index,
94
  replace_embeddings_with_none,
95
  schema_contains_path,
96
+ sparse_to_dense_compute,
97
  unflatten,
98
  wrap_in_dicts,
99
  write_item_embeddings_to_disk,
 
687
  star_in_cols = any(col.path == ('*',) for col in cols)
688
  if not cols or star_in_cols:
689
  # Select all columns.
690
+ cols.extend([Column((name,)) for name in schema.fields.keys()])
691
  if star_in_cols:
692
  cols = [col for col in cols if col.path != ('*',)]
693
  return cols
 
942
  # The input is an embedding.
943
  embedding_signal = cast(TextEmbeddingModelSignal, signal)
944
  vector_store = self.get_vector_store(embedding_signal.embedding, udf_col.path)
945
+ flat_keys = list(flatten_keys(df[UUID_COLUMN], input))
946
+ signal_out = sparse_to_dense_compute(
947
+ iter(flat_keys), lambda keys: signal.vector_compute(keys, vector_store))
948
  # Add progress.
949
  if task_step_id is not None:
950
  signal_out = progress(
 
955
  df[signal_column] = unflatten(signal_out, input)
956
  else:
957
  num_rich_data = count_primitives(input)
958
+ flat_input = cast(Iterator[Optional[RichData]], flatten(input))
959
+ signal_out = sparse_to_dense_compute(
960
+ flat_input, lambda x: signal.compute(cast(Iterable[RichData], x)))
961
  # Add progress.
962
  if task_step_id is not None:
963
  signal_out = progress(
 
965
  task_step_id=task_step_id,
966
  estimated_len=num_rich_data,
967
  step_description=f'Computing {signal.key()}...')
968
+ signal_out_list = list(signal_out)
 
969
  if signal_column in temp_column_to_offset_column:
970
  offset_column_name, field = temp_column_to_offset_column[signal_column]
971
+ nested_spans: Iterator[Item] = df[offset_column_name]
972
  flat_spans = list(flatten(nested_spans))
973
+ for span, item in zip(flat_spans, signal_out_list):
974
  _offset_any_span(cast(int, span[VALUE_KEY][TEXT_SPAN_START_FEATURE]), item, field)
975
 
976
+ if len(signal_out_list) != num_rich_data:
977
  raise ValueError(
978
+ f'The signal generated {len(signal_out_list)} values but the input data had '
979
  f"{num_rich_data} values. This means the signal either didn't generate a "
980
  '"None" for a sparse output, or generated too many items.')
981
 
982
+ df[signal_column] = unflatten(signal_out_list, input)
983
 
984
  signal.teardown()
985
 
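Note on the dataset_duckdb.py change above: both code paths now route the signal call through sparse_to_dense_compute, imported from dataset_utils (that helper is added elsewhere in this commit and its definition is not shown in this diff). Judging only from the call sites, it appears to skip None entries in the flattened input, run the signal over the remaining values, and re-insert None at the skipped positions so the output stays aligned with the input. The sketch below illustrates that assumed behavior; the name matches the import, but the eager list materialization and the toy example are illustrative assumptions, not the actual implementation.

from typing import Callable, Iterable, Iterator, Optional, TypeVar

TIn = TypeVar('TIn')
TOut = TypeVar('TOut')


def sparse_to_dense_compute(
    sparse_input: Iterator[Optional[TIn]],
    func: Callable[[Iterable[TIn]], Iterable[TOut]]) -> Iterator[Optional[TOut]]:
  """Run `func` over the non-None items, yielding outputs re-aligned to the sparse input."""
  items = list(sparse_input)  # Eager for brevity; a streaming version would avoid this copy.
  dense_outputs = iter(func([item for item in items if item is not None]))
  for item in items:
    # A skipped (None) input stays None; each real input consumes one output from `func`.
    yield None if item is None else next(dense_outputs)


# Usage mirroring the compute() call site: a toy length signal over sparse flattened text.
texts: list[Optional[str]] = ['hello', None, 'hi there']
assert list(sparse_to_dense_compute(iter(texts), lambda xs: (len(x) for x in xs))) == [5, None, 8]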
src/data/dataset_select_groups_test.py DELETED
@@ -1,317 +0,0 @@
1
- """Tests for dataset.select_groups()."""
2
-
3
- import re
4
-
5
- import pytest
6
- from pytest_mock import MockerFixture
7
-
8
- from ..schema import UUID_COLUMN, Item, field, schema
9
- from . import dataset as dataset_module
10
- from .dataset import BinaryOp
11
- from .dataset_test_utils import TestDataMaker
12
-
13
-
14
- def test_flat_data(make_test_data: TestDataMaker) -> None:
15
- items: list[Item] = [
16
- {
17
- 'name': 'Name1',
18
- 'age': 34,
19
- 'active': False
20
- },
21
- {
22
- 'name': 'Name2',
23
- 'age': 45,
24
- 'active': True
25
- },
26
- {
27
- 'age': 17,
28
- 'active': True
29
- }, # Missing "name".
30
- {
31
- 'name': 'Name3',
32
- 'active': True
33
- }, # Missing "age".
34
- {
35
- 'name': 'Name4',
36
- 'age': 55
37
- } # Missing "active".
38
- ]
39
- dataset = make_test_data(items)
40
-
41
- result = dataset.select_groups(leaf_path='name')
42
- assert result.counts == [('Name1', 1), ('Name2', 1), (None, 1), ('Name3', 1), ('Name4', 1)]
43
-
44
- result = dataset.select_groups(leaf_path='age', bins=[20, 50, 60])
45
- assert result.counts == [('1', 2), ('0', 1), (None, 1), ('2', 1)]
46
-
47
- result = dataset.select_groups(leaf_path='active')
48
- assert result.counts == [
49
- (True, 3),
50
- (False, 1),
51
- (None, 1), # Missing "active".
52
- ]
53
-
54
-
55
- def test_result_counts(make_test_data: TestDataMaker) -> None:
56
- items: list[Item] = [
57
- {
58
- 'active': False
59
- },
60
- {
61
- 'active': True
62
- },
63
- {
64
- 'active': True
65
- },
66
- {
67
- 'active': True
68
- },
69
- {} # Missing "active".
70
- ]
71
- dataset = make_test_data(items, schema=schema({UUID_COLUMN: 'string', 'active': 'boolean'}))
72
-
73
- result = dataset.select_groups(leaf_path='active')
74
- assert result.counts == [(True, 3), (False, 1), (None, 1)]
75
-
76
-
77
- def test_list_of_structs(make_test_data: TestDataMaker) -> None:
78
- items: list[Item] = [{
79
- 'list_of_structs': [{
80
- 'name': 'a'
81
- }, {
82
- 'name': 'b'
83
- }]
84
- }, {
85
- 'list_of_structs': [{
86
- 'name': 'c'
87
- }, {
88
- 'name': 'a'
89
- }, {
90
- 'name': 'd'
91
- }]
92
- }, {
93
- 'list_of_structs': [{
94
- 'name': 'd'
95
- }]
96
- }]
97
- dataset = make_test_data(items)
98
-
99
- result = dataset.select_groups(leaf_path='list_of_structs.*.name')
100
- assert result.counts == [('a', 2), ('d', 2), ('b', 1), ('c', 1)]
101
-
102
-
103
- def test_nested_lists(make_test_data: TestDataMaker) -> None:
104
- items: list[Item] = [{
105
- 'nested_list': [[{
106
- 'name': 'a'
107
- }], [{
108
- 'name': 'b'
109
- }]]
110
- }, {
111
- 'nested_list': [[{
112
- 'name': 'c'
113
- }, {
114
- 'name': 'a'
115
- }], [{
116
- 'name': 'd'
117
- }]]
118
- }, {
119
- 'nested_list': [[{
120
- 'name': 'd'
121
- }]]
122
- }]
123
- dataset = make_test_data(items)
124
-
125
- result = dataset.select_groups(leaf_path='nested_list.*.*.name')
126
- assert result.counts == [('a', 2), ('d', 2), ('b', 1), ('c', 1)]
127
-
128
-
129
- def test_nested_struct(make_test_data: TestDataMaker) -> None:
130
- items: list[Item] = [
131
- {
132
- 'nested_struct': {
133
- 'struct': {
134
- 'name': 'c'
135
- }
136
- }
137
- },
138
- {
139
- 'nested_struct': {
140
- 'struct': {
141
- 'name': 'b'
142
- }
143
- }
144
- },
145
- {
146
- 'nested_struct': {
147
- 'struct': {
148
- 'name': 'a'
149
- }
150
- }
151
- },
152
- ]
153
- dataset = make_test_data(items)
154
-
155
- result = dataset.select_groups(leaf_path='nested_struct.struct.name')
156
- assert result.counts == [('c', 1), ('b', 1), ('a', 1)]
157
-
158
-
159
- def test_named_bins(make_test_data: TestDataMaker) -> None:
160
- items: list[Item] = [{
161
- 'age': 34,
162
- }, {
163
- 'age': 45,
164
- }, {
165
- 'age': 17,
166
- }, {
167
- 'age': 80
168
- }, {
169
- 'age': 55
170
- }, {
171
- 'age': float('nan')
172
- }]
173
- dataset = make_test_data(items)
174
-
175
- result = dataset.select_groups(
176
- leaf_path='age',
177
- bins=[
178
- ('young', None, 20),
179
- ('adult', 20, 50),
180
- ('middle-aged', 50, 65),
181
- ('senior', 65, None),
182
- ])
183
- assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]
184
-
185
-
186
- def test_schema_with_bins(make_test_data: TestDataMaker) -> None:
187
- items: list[Item] = [{
188
- 'age': 34,
189
- }, {
190
- 'age': 45,
191
- }, {
192
- 'age': 17,
193
- }, {
194
- 'age': 80
195
- }, {
196
- 'age': 55
197
- }, {
198
- 'age': float('nan')
199
- }]
200
- data_schema = schema({
201
- UUID_COLUMN: 'string',
202
- 'age': field(
203
- 'float32',
204
- bins=[
205
- ('young', None, 20),
206
- ('adult', 20, 50),
207
- ('middle-aged', 50, 65),
208
- ('senior', 65, None),
209
- ])
210
- })
211
- dataset = make_test_data(items, data_schema)
212
-
213
- result = dataset.select_groups(leaf_path='age')
214
- assert result.counts == [('adult', 2), ('young', 1), ('senior', 1), ('middle-aged', 1), (None, 1)]
215
-
216
-
217
- def test_filters(make_test_data: TestDataMaker) -> None:
218
- items: list[Item] = [
219
- {
220
- 'name': 'Name1',
221
- 'age': 34,
222
- 'active': False
223
- },
224
- {
225
- 'name': 'Name2',
226
- 'age': 45,
227
- 'active': True
228
- },
229
- {
230
- 'age': 17,
231
- 'active': True
232
- }, # Missing "name".
233
- {
234
- 'name': 'Name3',
235
- 'active': True
236
- }, # Missing "age".
237
- {
238
- 'name': 'Name4',
239
- 'age': 55
240
- } # Missing "active".
241
- ]
242
- dataset = make_test_data(items)
243
-
244
- # active = True.
245
- result = dataset.select_groups(leaf_path='name', filters=[('active', BinaryOp.EQUALS, True)])
246
- assert result.counts == [('Name2', 1), (None, 1), ('Name3', 1)]
247
-
248
- # age < 35.
249
- result = dataset.select_groups(leaf_path='name', filters=[('age', BinaryOp.LESS, 35)])
250
- assert result.counts == [('Name1', 1), (None, 1)]
251
-
252
- # age < 35 and active = True.
253
- result = dataset.select_groups(
254
- leaf_path='name', filters=[('age', BinaryOp.LESS, 35), ('active', BinaryOp.EQUALS, True)])
255
- assert result.counts == [(None, 1)]
256
-
257
-
258
- def test_invalid_leaf(make_test_data: TestDataMaker) -> None:
259
- items: list[Item] = [
260
- {
261
- 'nested_struct': {
262
- 'struct': {
263
- 'name': 'c'
264
- }
265
- }
266
- },
267
- {
268
- 'nested_struct': {
269
- 'struct': {
270
- 'name': 'b'
271
- }
272
- }
273
- },
274
- {
275
- 'nested_struct': {
276
- 'struct': {
277
- 'name': 'a'
278
- }
279
- }
280
- },
281
- ]
282
- dataset = make_test_data(items)
283
-
284
- with pytest.raises(
285
- ValueError, match=re.escape("Leaf \"('nested_struct',)\" not found in dataset")):
286
- dataset.select_groups(leaf_path='nested_struct')
287
-
288
- with pytest.raises(
289
- ValueError, match=re.escape("Leaf \"('nested_struct', 'struct')\" not found in dataset")):
290
- dataset.select_groups(leaf_path='nested_struct.struct')
291
-
292
- with pytest.raises(
293
- ValueError,
294
- match=re.escape("Leaf \"('nested_struct', 'struct', 'wrong_name')\" not found in dataset")):
295
- dataset.select_groups(leaf_path='nested_struct.struct.wrong_name')
296
-
297
-
298
- def test_too_many_distinct(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
299
- too_many_distinct = 5
300
- mocker.patch(f'{dataset_module.__name__}.TOO_MANY_DISTINCT', too_many_distinct)
301
-
302
- items: list[Item] = [{'feature': str(i)} for i in range(too_many_distinct + 10)]
303
- dataset = make_test_data(items)
304
-
305
- res = dataset.select_groups('feature')
306
- assert res.too_many_distinct is True
307
- assert res.counts == []
308
-
309
-
310
- def test_auto_bins_for_float(make_test_data: TestDataMaker) -> None:
311
- items: list[Item] = [{'feature': float(i)} for i in range(5)] + [{'feature': float('nan')}]
312
- dataset = make_test_data(items)
313
-
314
- res = dataset.select_groups('feature')
315
- assert res.counts == [('0', 1), ('3', 1), ('7', 1), ('11', 1), ('14', 1), (None, 1)]
316
- assert res.too_many_distinct is False
317
- assert res.bins
src/data/dataset_select_rows_filter_test.py DELETED
@@ -1,200 +0,0 @@
1
- """Tests for dataset.select_rows(filters=[...])."""
2
-
3
- import pytest
4
-
5
- from ..schema import UUID_COLUMN, Item, schema
6
- from .dataset import BinaryFilterTuple, BinaryOp, ListFilterTuple, ListOp, UnaryOp
7
- from .dataset_test_utils import TestDataMaker
8
-
9
- TEST_DATA: list[Item] = [{
10
- UUID_COLUMN: '1',
11
- 'str': 'a',
12
- 'int': 1,
13
- 'bool': False,
14
- 'float': 3.0
15
- }, {
16
- UUID_COLUMN: '2',
17
- 'str': 'b',
18
- 'int': 2,
19
- 'bool': True,
20
- 'float': 2.0
21
- }, {
22
- UUID_COLUMN: '3',
23
- 'str': 'b',
24
- 'int': 2,
25
- 'bool': True,
26
- 'float': 1.0
27
- }, {
28
- UUID_COLUMN: '4',
29
- 'float': float('nan')
30
- }]
31
-
32
-
33
- def test_filter_by_ids(make_test_data: TestDataMaker) -> None:
34
- dataset = make_test_data(TEST_DATA)
35
-
36
- id_filter: BinaryFilterTuple = (UUID_COLUMN, BinaryOp.EQUALS, '1')
37
- result = dataset.select_rows(filters=[id_filter])
38
-
39
- assert list(result) == [{UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}]
40
-
41
- id_filter = (UUID_COLUMN, BinaryOp.EQUALS, '2')
42
- result = dataset.select_rows(filters=[id_filter])
43
-
44
- assert list(result) == [{UUID_COLUMN: '2', 'str': 'b', 'int': 2, 'bool': True, 'float': 2.0}]
45
-
46
- id_filter = (UUID_COLUMN, BinaryOp.EQUALS, b'f')
47
- result = dataset.select_rows(filters=[id_filter])
48
-
49
- assert list(result) == []
50
-
51
-
52
- def test_filter_greater(make_test_data: TestDataMaker) -> None:
53
- dataset = make_test_data(TEST_DATA)
54
-
55
- id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER, 2.0)
56
- result = dataset.select_rows(filters=[id_filter])
57
-
58
- assert list(result) == [{UUID_COLUMN: '1', 'str': 'a', 'int': 1, 'bool': False, 'float': 3.0}]
59
-
60
-
61
- def test_filter_greater_equal(make_test_data: TestDataMaker) -> None:
62
- dataset = make_test_data(TEST_DATA)
63
-
64
- id_filter: BinaryFilterTuple = ('float', BinaryOp.GREATER_EQUAL, 2.0)
65
- result = dataset.select_rows(filters=[id_filter])
66
-
67
- assert list(result) == [{
68
- UUID_COLUMN: '1',
69
- 'str': 'a',
70
- 'int': 1,
71
- 'bool': False,
72
- 'float': 3.0
73
- }, {
74
- UUID_COLUMN: '2',
75
- 'str': 'b',
76
- 'int': 2,
77
- 'bool': True,
78
- 'float': 2.0
79
- }]
80
-
81
-
82
- def test_filter_less(make_test_data: TestDataMaker) -> None:
83
- dataset = make_test_data(TEST_DATA)
84
-
85
- id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS, 2.0)
86
- result = dataset.select_rows(filters=[id_filter])
87
-
88
- assert list(result) == [{UUID_COLUMN: '3', 'str': 'b', 'int': 2, 'bool': True, 'float': 1.0}]
89
-
90
-
91
- def test_filter_less_equal(make_test_data: TestDataMaker) -> None:
92
- dataset = make_test_data(TEST_DATA)
93
-
94
- id_filter: BinaryFilterTuple = ('float', BinaryOp.LESS_EQUAL, 2.0)
95
- result = dataset.select_rows(filters=[id_filter])
96
-
97
- assert list(result) == [{
98
- UUID_COLUMN: '2',
99
- 'str': 'b',
100
- 'int': 2,
101
- 'bool': True,
102
- 'float': 2.0
103
- }, {
104
- UUID_COLUMN: '3',
105
- 'str': 'b',
106
- 'int': 2,
107
- 'bool': True,
108
- 'float': 1.0
109
- }]
110
-
111
-
112
- def test_filter_not_equal(make_test_data: TestDataMaker) -> None:
113
- dataset = make_test_data(TEST_DATA)
114
-
115
- id_filter: BinaryFilterTuple = ('float', BinaryOp.NOT_EQUAL, 2.0)
116
- result = dataset.select_rows(filters=[id_filter])
117
-
118
- assert list(result) == [
119
- {
120
- UUID_COLUMN: '1',
121
- 'str': 'a',
122
- 'int': 1,
123
- 'bool': False,
124
- 'float': 3.0
125
- },
126
- {
127
- UUID_COLUMN: '3',
128
- 'str': 'b',
129
- 'int': 2,
130
- 'bool': True,
131
- 'float': 1.0
132
- },
133
- # NaNs are not counted when we are filtering a field.
134
- ]
135
-
136
-
137
- def test_filter_by_list_of_ids(make_test_data: TestDataMaker) -> None:
138
- dataset = make_test_data(TEST_DATA)
139
-
140
- id_filter: ListFilterTuple = (UUID_COLUMN, ListOp.IN, ['1', '2'])
141
- result = dataset.select_rows(filters=[id_filter])
142
-
143
- assert list(result) == [{
144
- UUID_COLUMN: '1',
145
- 'str': 'a',
146
- 'int': 1,
147
- 'bool': False,
148
- 'float': 3.0
149
- }, {
150
- UUID_COLUMN: '2',
151
- 'str': 'b',
152
- 'int': 2,
153
- 'bool': True,
154
- 'float': 2.0
155
- }]
156
-
157
-
158
- def test_filter_by_exists(make_test_data: TestDataMaker) -> None:
159
- items: list[Item] = [{
160
- UUID_COLUMN: '1',
161
- 'name': 'A',
162
- 'info': {
163
- 'lang': 'en'
164
- },
165
- 'ages': []
166
- }, {
167
- UUID_COLUMN: '2',
168
- 'info': {
169
- 'lang': 'fr'
170
- },
171
- }, {
172
- UUID_COLUMN: '3',
173
- 'name': 'C',
174
- 'ages': [[1, 2], [3, 4]]
175
- }]
176
- dataset = make_test_data(
177
- items,
178
- schema=schema({
179
- UUID_COLUMN: 'string',
180
- 'name': 'string',
181
- 'info': {
182
- 'lang': 'string'
183
- },
184
- 'ages': [['int32']]
185
- }))
186
-
187
- exists_filter = ('name', UnaryOp.EXISTS)
188
- result = dataset.select_rows(['name'], filters=[exists_filter])
189
- assert list(result) == [{UUID_COLUMN: '1', 'name': 'A'}, {UUID_COLUMN: '3', 'name': 'C'}]
190
-
191
- exists_filter = ('info.lang', UnaryOp.EXISTS)
192
- result = dataset.select_rows(['name'], filters=[exists_filter])
193
- assert list(result) == [{UUID_COLUMN: '1', 'name': 'A'}, {UUID_COLUMN: '2', 'name': None}]
194
-
195
- exists_filter = ('ages.*.*', UnaryOp.EXISTS)
196
- result = dataset.select_rows(['name'], filters=[exists_filter])
197
- assert list(result) == [{UUID_COLUMN: '3', 'name': 'C'}]
198
-
199
- with pytest.raises(ValueError, match='Unable to filter on path'):
200
- dataset.select_rows(['name'], filters=[('info', UnaryOp.EXISTS)])
src/data/dataset_select_rows_schema_test.py DELETED
@@ -1,551 +0,0 @@
1
- """Tests for `db.select_rows_schema()`."""
2
-
3
- from typing import Iterable, Optional, cast
4
-
5
- import numpy as np
6
- import pytest
7
- from typing_extensions import override
8
-
9
- from ..embeddings.vector_store import VectorStore
10
- from ..schema import PATH_WILDCARD, UUID_COLUMN, Field, Item, RichData, VectorKey, field, schema
11
- from ..signals.concept_labels import ConceptLabelsSignal
12
- from ..signals.concept_scorer import ConceptScoreSignal
13
- from ..signals.semantic_similarity import SemanticSimilaritySignal
14
- from ..signals.signal import (
15
- EMBEDDING_KEY,
16
- TextEmbeddingModelSignal,
17
- TextEmbeddingSignal,
18
- TextSignal,
19
- TextSplitterSignal,
20
- clear_signal_registry,
21
- register_signal,
22
- )
23
- from ..signals.substring_search import SubstringSignal
24
- from .dataset import (
25
- Column,
26
- ConceptQuery,
27
- KeywordQuery,
28
- Search,
29
- SearchResultInfo,
30
- SelectRowsSchemaResult,
31
- SelectRowsSchemaUDF,
32
- SemanticQuery,
33
- SortOrder,
34
- SortResult,
35
- )
36
- from .dataset_test_utils import (
37
- TEST_DATASET_NAME,
38
- TEST_NAMESPACE,
39
- TestDataMaker,
40
- enriched_embedding_span_field,
41
- )
42
- from .dataset_utils import lilac_embedding, lilac_span
43
-
44
- TEST_DATA: list[Item] = [{
45
- UUID_COLUMN: '1',
46
- 'erased': False,
47
- 'people': [{
48
- 'name': 'A',
49
- 'zipcode': 0,
50
- 'locations': [{
51
- 'city': 'city1',
52
- 'state': 'state1'
53
- }, {
54
- 'city': 'city2',
55
- 'state': 'state2'
56
- }]
57
- }]
58
- }, {
59
- UUID_COLUMN: '2',
60
- 'erased': True,
61
- 'people': [{
62
- 'name': 'B',
63
- 'zipcode': 1,
64
- 'locations': [{
65
- 'city': 'city3',
66
- 'state': 'state3'
67
- }, {
68
- 'city': 'city4'
69
- }, {
70
- 'city': 'city5'
71
- }]
72
- }, {
73
- 'name': 'C',
74
- 'zipcode': 2,
75
- 'locations': [{
76
- 'city': 'city1',
77
- 'state': 'state1'
78
- }]
79
- }]
80
- }]
81
-
82
-
83
- class TestSplitter(TextSplitterSignal):
84
- """Split documents into sentence by splitting on period."""
85
- name = 'test_splitter'
86
-
87
- @override
88
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
89
- for text in data:
90
- if not isinstance(text, str):
91
- raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
92
- sentences = [f'{sentence.strip()}.' for sentence in text.split('.') if sentence]
93
- yield [
94
- lilac_span(text.index(sentence),
95
- text.index(sentence) + len(sentence)) for sentence in sentences
96
- ]
97
-
98
-
99
- EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
100
- ('hello2.', [1.0, 1.0, 0.0]),
101
- ('hello world.', [1.0, 1.0, 1.0]),
102
- ('hello world2.', [2.0, 1.0, 1.0])]
103
-
104
- STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
105
-
106
-
107
- class TestEmbedding(TextEmbeddingSignal):
108
- """A test embed function."""
109
- name = 'test_embedding'
110
-
111
- @override
112
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
113
- """Call the embedding function."""
114
- for example in data:
115
- yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
116
-
117
-
118
- class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
119
- """Sums the embeddings to return a single floating point value."""
120
- name = 'test_embedding_sum'
121
-
122
- @override
123
- def fields(self) -> Field:
124
- return field('float32')
125
-
126
- @override
127
- def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
128
- # The signal just sums the values of the embedding.
129
- embedding_sums = vector_store.get(keys).sum(axis=1)
130
- for embedding_sum in embedding_sums.tolist():
131
- yield embedding_sum
132
-
133
-
134
- @pytest.fixture(scope='module', autouse=True)
135
- def setup_teardown() -> Iterable[None]:
136
- # Setup.
137
- register_signal(LengthSignal)
138
- register_signal(AddSpaceSignal)
139
- register_signal(TestSplitter)
140
- register_signal(TestEmbedding)
141
- register_signal(TestEmbeddingSumSignal)
142
-
143
- # Unit test runs.
144
- yield
145
-
146
- # Teardown.
147
- clear_signal_registry()
148
-
149
-
150
- class LengthSignal(TextSignal):
151
- name = 'length_signal'
152
-
153
- def fields(self) -> Field:
154
- return field('int32')
155
-
156
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
157
- for text_content in data:
158
- yield len(text_content)
159
-
160
-
161
- class AddSpaceSignal(TextSignal):
162
- name = 'add_space_signal'
163
-
164
- def fields(self) -> Field:
165
- return field('string')
166
-
167
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
168
- for text_content in data:
169
- yield cast(str, text_content) + ' '
170
-
171
-
172
- def test_simple_schema(make_test_data: TestDataMaker) -> None:
173
- dataset = make_test_data(TEST_DATA)
174
- result = dataset.select_rows_schema(combine_columns=True)
175
- assert result == SelectRowsSchemaResult(
176
- data_schema=schema({
177
- UUID_COLUMN: 'string',
178
- 'erased': 'boolean',
179
- 'people': [{
180
- 'name': 'string',
181
- 'zipcode': 'int32',
182
- 'locations': [{
183
- 'city': 'string',
184
- 'state': 'string'
185
- }]
186
- }]
187
- }))
188
-
189
-
190
- def test_subselection_with_combine_cols(make_test_data: TestDataMaker) -> None:
191
- dataset = make_test_data(TEST_DATA)
192
-
193
- result = dataset.select_rows_schema([('people', '*', 'zipcode'),
194
- ('people', '*', 'locations', '*', 'city')],
195
- combine_columns=True)
196
- assert result == SelectRowsSchemaResult(
197
- data_schema=schema({
198
- UUID_COLUMN: 'string',
199
- 'people': [{
200
- 'zipcode': 'int32',
201
- 'locations': [{
202
- 'city': 'string'
203
- }]
204
- }]
205
- }))
206
-
207
- result = dataset.select_rows_schema([('people', '*', 'name'), ('people', '*', 'locations')],
208
- combine_columns=True)
209
- assert result == SelectRowsSchemaResult(
210
- data_schema=schema({
211
- UUID_COLUMN: 'string',
212
- 'people': [{
213
- 'name': 'string',
214
- 'locations': [{
215
- 'city': 'string',
216
- 'state': 'string'
217
- }]
218
- }]
219
- }))
220
-
221
- result = dataset.select_rows_schema([('people', '*')], combine_columns=True)
222
- assert result == SelectRowsSchemaResult(
223
- namespace=TEST_NAMESPACE,
224
- dataset_name=TEST_DATASET_NAME,
225
- data_schema=schema({
226
- UUID_COLUMN: 'string',
227
- 'people': [{
228
- 'name': 'string',
229
- 'zipcode': 'int32',
230
- 'locations': [{
231
- 'city': 'string',
232
- 'state': 'string'
233
- }]
234
- }]
235
- }))
236
-
237
-
238
- def test_udf_with_combine_cols(make_test_data: TestDataMaker) -> None:
239
- dataset = make_test_data(TEST_DATA)
240
-
241
- length_signal = LengthSignal()
242
- result = dataset.select_rows_schema([('people', '*', 'locations', '*', 'city'),
243
- Column(('people', '*', 'name'), signal_udf=length_signal)],
244
- combine_columns=True)
245
- assert result == SelectRowsSchemaResult(
246
- data_schema=schema({
247
- UUID_COLUMN: 'string',
248
- 'people': [{
249
- 'name': {
250
- 'length_signal': field('int32', length_signal.dict())
251
- },
252
- 'locations': [{
253
- 'city': 'string'
254
- }]
255
- }],
256
- }),
257
- udfs=[
258
- SelectRowsSchemaUDF(path=('people', '*', 'name', length_signal.key())),
259
- ],
260
- )
261
-
262
-
263
- def test_embedding_udf_with_combine_cols(make_test_data: TestDataMaker) -> None:
264
- dataset = make_test_data(TEST_DATA)
265
-
266
- add_space_signal = AddSpaceSignal()
267
- path = ('people', '*', 'name')
268
- dataset.compute_signal(add_space_signal, path)
269
- result = dataset.select_rows_schema([path, Column(path, signal_udf=add_space_signal)],
270
- combine_columns=True)
271
- assert result == SelectRowsSchemaResult(
272
- data_schema=schema({
273
- UUID_COLUMN: 'string',
274
- 'people': [{
275
- 'name': field(
276
- 'string', fields={'add_space_signal': field('string', signal=add_space_signal.dict())})
277
- }],
278
- }),
279
- udfs=[
280
- SelectRowsSchemaUDF(path=(*path, add_space_signal.key())),
281
- ],
282
- )
283
-
284
-
285
- def test_udf_chained_with_combine_cols(make_test_data: TestDataMaker) -> None:
286
- dataset = make_test_data([{
287
- UUID_COLUMN: '1',
288
- 'text': 'hello. hello2.',
289
- }, {
290
- UUID_COLUMN: '2',
291
- 'text': 'hello world. hello world2.',
292
- }])
293
-
294
- test_splitter = TestSplitter()
295
- dataset.compute_signal(test_splitter, ('text'))
296
- add_space_signal = AddSpaceSignal()
297
- result = dataset.select_rows_schema(
298
- [('text'), Column(('text'), signal_udf=add_space_signal)], combine_columns=True)
299
-
300
- assert result == SelectRowsSchemaResult(
301
- data_schema=schema({
302
- UUID_COLUMN: 'string',
303
- 'text': field(
304
- 'string',
305
- fields={
306
- 'add_space_signal': field('string', add_space_signal.dict()),
307
- 'test_splitter': field(signal=test_splitter.dict(), fields=['string_span'])
308
- })
309
- }),
310
- udfs=[
311
- SelectRowsSchemaUDF(path=('text', add_space_signal.key())),
312
- ],
313
- )
314
-
315
-
316
- def test_udf_embedding_chained_with_combine_cols(make_test_data: TestDataMaker) -> None:
317
- dataset = make_test_data([{
318
- UUID_COLUMN: '1',
319
- 'text': 'hello. hello2.',
320
- }, {
321
- UUID_COLUMN: '2',
322
- 'text': 'hello world. hello world2.',
323
- }])
324
-
325
- test_splitter = TestSplitter()
326
- dataset.compute_signal(test_splitter, 'text')
327
- test_embedding = TestEmbedding()
328
- dataset.compute_signal(test_embedding, ('text', 'test_splitter', '*'))
329
-
330
- embedding_sum_signal = TestEmbeddingSumSignal(embedding='test_embedding')
331
- udf_col = Column(('text', 'test_splitter', '*'), signal_udf=embedding_sum_signal)
332
- result = dataset.select_rows_schema([('text'), udf_col], combine_columns=True)
333
-
334
- expected_schema = schema({
335
- UUID_COLUMN: 'string',
336
- 'text': field(
337
- 'string',
338
- fields={
339
- 'test_splitter': field(
340
- signal=test_splitter.dict(),
341
- fields=[
342
- field(
343
- 'string_span',
344
- fields={
345
- 'test_embedding': field(
346
- signal=test_embedding.dict(),
347
- fields=[
348
- enriched_embedding_span_field(
349
- {'test_embedding_sum': field('float32', embedding_sum_signal.dict())})
350
- ])
351
- })
352
- ])
353
- })
354
- })
355
- output_path = ('text', 'test_splitter', '*', 'test_embedding', '*', 'embedding',
356
- 'test_embedding_sum')
357
- assert result == SelectRowsSchemaResult(
358
- data_schema=expected_schema,
359
- udfs=[SelectRowsSchemaUDF(path=output_path)],
360
- )
361
-
362
- # Alias the udf.
363
- udf_col.alias = 'udf1'
364
- result = dataset.select_rows_schema([('text'), udf_col], combine_columns=True)
365
- assert result == SelectRowsSchemaResult(
366
- data_schema=expected_schema,
367
- udfs=[SelectRowsSchemaUDF(path=output_path, alias='udf1')],
368
- )
369
-
370
-
371
- def test_search_keyword_schema(make_test_data: TestDataMaker) -> None:
372
- dataset = make_test_data([{
373
- UUID_COLUMN: '1',
374
- 'text': 'hello world',
375
- 'text2': 'hello world2',
376
- }])
377
- query_world = 'world'
378
- query_hello = 'hello'
379
-
380
- result = dataset.select_rows_schema(
381
- searches=[
382
- Search(path='text', query=KeywordQuery(type='keyword', search=query_world)),
383
- Search(path='text2', query=KeywordQuery(type='keyword', search=query_hello)),
384
- ],
385
- combine_columns=True)
386
-
387
- expected_world_signal = SubstringSignal(query=query_world)
388
- expected_hello_signal = SubstringSignal(query=query_hello)
389
-
390
- assert result == SelectRowsSchemaResult(
391
- data_schema=schema({
392
- UUID_COLUMN: 'string',
393
- 'text': field(
394
- 'string',
395
- fields={
396
- expected_world_signal.key(): field(
397
- signal=expected_world_signal.dict(), fields=['string_span'])
398
- }),
399
- 'text2': field(
400
- 'string',
401
- fields={
402
- expected_hello_signal.key(): field(
403
- signal=expected_hello_signal.dict(), fields=['string_span'])
404
- })
405
- }),
406
- search_results=[
407
- SearchResultInfo(
408
- search_path=('text',),
409
- result_path=('text', expected_world_signal.key(), PATH_WILDCARD),
410
- ),
411
- SearchResultInfo(
412
- search_path=('text2',),
413
- result_path=('text2', expected_hello_signal.key(), PATH_WILDCARD),
414
- )
415
- ],
416
- udfs=[
417
- SelectRowsSchemaUDF(path=('text', expected_world_signal.key())),
418
- SelectRowsSchemaUDF(path=('text2', expected_hello_signal.key())),
419
- ],
420
- )
421
-
422
-
423
- def test_search_semantic_schema(make_test_data: TestDataMaker) -> None:
424
- dataset = make_test_data([{
425
- UUID_COLUMN: '1',
426
- 'text': 'hello world.',
427
- }])
428
- query_world = 'world'
429
-
430
- test_embedding = TestEmbedding()
431
- dataset.compute_signal(test_embedding, ('text'))
432
-
433
- result = dataset.select_rows_schema(
434
- searches=[
435
- Search(
436
- path='text',
437
- query=SemanticQuery(type='semantic', search=query_world, embedding='test_embedding')),
438
- ],
439
- combine_columns=True)
440
-
441
- test_embedding = TestEmbedding()
442
- expected_world_signal = SemanticSimilaritySignal(query=query_world, embedding='test_embedding')
443
-
444
- similarity_score_path = ('text', 'test_embedding', PATH_WILDCARD, EMBEDDING_KEY,
445
- expected_world_signal.key())
446
- assert result == SelectRowsSchemaResult(
447
- data_schema=schema({
448
- UUID_COLUMN: 'string',
449
- 'text': field(
450
- 'string',
451
- fields={
452
- 'test_embedding': field(
453
- signal=test_embedding.dict(),
454
- fields=[
455
- enriched_embedding_span_field(
456
- {expected_world_signal.key(): field('float32', expected_world_signal.dict())})
457
- ])
458
- })
459
- }),
460
- udfs=[SelectRowsSchemaUDF(path=similarity_score_path)],
461
- search_results=[SearchResultInfo(search_path=('text',), result_path=similarity_score_path)],
462
- sorts=[SortResult(path=similarity_score_path, order=SortOrder.DESC, search_index=0)])
463
-
464
-
465
- def test_search_concept_schema(make_test_data: TestDataMaker) -> None:
466
- dataset = make_test_data([{
467
- UUID_COLUMN: '1',
468
- 'text': 'hello world.',
469
- }])
470
-
471
- test_embedding = TestEmbedding()
472
- dataset.compute_signal(test_embedding, ('text'))
473
-
474
- result = dataset.select_rows_schema(
475
- searches=[
476
- Search(
477
- path='text',
478
- query=ConceptQuery(
479
- type='concept',
480
- concept_namespace='test_namespace',
481
- concept_name='test_concept',
482
- embedding='test_embedding')),
483
- ],
484
- combine_columns=True)
485
-
486
- test_embedding = TestEmbedding()
487
- expected_world_signal = ConceptScoreSignal(
488
- namespace='test_namespace', concept_name='test_concept', embedding='test_embedding')
489
- expected_labels_signal = ConceptLabelsSignal(
490
- namespace='test_namespace', concept_name='test_concept')
491
-
492
- concept_score_path = ('text', 'test_embedding', PATH_WILDCARD, EMBEDDING_KEY,
493
- expected_world_signal.key())
494
- concept_labels_path = ('text', expected_labels_signal.key())
495
- assert result == SelectRowsSchemaResult(
496
- data_schema=schema({
497
- UUID_COLUMN: 'string',
498
- 'text': field(
499
- 'string',
500
- fields={
501
- 'test_embedding': field(
502
- signal=test_embedding.dict(),
503
- fields=[
504
- enriched_embedding_span_field({
505
- expected_world_signal.key(): field(
506
- 'float32',
507
- expected_world_signal.dict(),
508
- bins=[('Not in concept', None, 0.5), ('In concept', 0.5, None)])
509
- })
510
- ]),
511
- 'test_namespace/test_concept/labels': field(
512
- fields=[field('string_span', fields={
513
- 'label': 'boolean',
514
- 'draft': 'string'
515
- })],
516
- signal=expected_labels_signal.dict())
517
- })
518
- }),
519
- udfs=[
520
- SelectRowsSchemaUDF(path=concept_labels_path),
521
- SelectRowsSchemaUDF(path=concept_score_path)
522
- ],
523
- search_results=[
524
- SearchResultInfo(search_path=('text',), result_path=concept_labels_path),
525
- SearchResultInfo(search_path=('text',), result_path=concept_score_path)
526
- ],
527
- sorts=[SortResult(path=concept_score_path, order=SortOrder.DESC, search_index=0)])
528
-
529
-
530
- def test_search_sort_override(make_test_data: TestDataMaker) -> None:
531
- dataset = make_test_data([{
532
- UUID_COLUMN: '1',
533
- 'text': 'hello world.',
534
- }])
535
- query_world = 'world'
536
-
537
- test_embedding = TestEmbedding()
538
- dataset.compute_signal(test_embedding, ('text'))
539
-
540
- result = dataset.select_rows_schema(
541
- searches=[
542
- Search(
543
- path='text',
544
- query=SemanticQuery(type='semantic', search=query_world, embedding='test_embedding')),
545
- ],
546
- # Explicit sort by overrides the semantic search.
547
- sort_by=[('text',)],
548
- sort_order=SortOrder.DESC,
549
- combine_columns=True)
550
-
551
- assert result.sorts == [SortResult(path=('text',), order=SortOrder.DESC)]
 
src/data/dataset_select_rows_search_test.py DELETED
@@ -1,393 +0,0 @@
1
- """Tests for dataset.select_rows(searches=[...])."""
2
-
3
- from typing import Iterable, cast
4
-
5
- import numpy as np
6
- import pytest
7
- from pytest import approx
8
- from pytest_mock import MockerFixture
9
- from sklearn.preprocessing import normalize
10
- from typing_extensions import override
11
-
12
- from ..concepts.concept import ExampleIn, LogisticEmbeddingModel
13
- from ..concepts.db_concept import ConceptUpdate, DiskConceptDB
14
- from ..db_manager import set_default_dataset_cls
15
- from ..schema import UUID_COLUMN, Item, RichData, SignalInputType
16
- from ..signals.concept_scorer import ConceptScoreSignal
17
- from ..signals.semantic_similarity import SemanticSimilaritySignal
18
- from ..signals.signal import TextEmbeddingSignal, clear_signal_registry, register_signal
19
- from ..signals.substring_search import SubstringSignal
20
- from .dataset import ConceptQuery, KeywordQuery, ListOp, Search, SemanticQuery, SortOrder
21
- from .dataset_duckdb import DatasetDuckDB
22
- from .dataset_test_utils import TestDataMaker, enriched_embedding_span, enriched_item
23
- from .dataset_utils import lilac_embedding, lilac_span
24
-
25
- TEST_DATA: list[Item] = [{
26
- UUID_COLUMN: '1',
27
- 'text': 'hello world',
28
- 'text2': 'again hello world',
29
- }, {
30
- UUID_COLUMN: '2',
31
- 'text': 'looking for world in text',
32
- 'text2': 'again looking for world in text',
33
- }, {
34
- UUID_COLUMN: '3',
35
- 'text': 'unrelated text',
36
- 'text2': 'again unrelated text'
37
- }]
38
-
39
- EMBEDDINGS: list[tuple[str, list[float]]] = [
40
- ('hello.', [1.0, 0.0, 0.0]),
41
- ('hello2.', [1.0, 1.0, 0.0]),
42
- ('hello world.', [1.0, 1.0, 1.0]),
43
- ('hello world2.', [2.0, 1.0, 1.0]),
44
- ('random negative 1', [0, 0, 0.3]),
45
- ('random negative 2', [0, 0, 0.4]),
46
- ('random negative 3', [0, 0.1, 0.5]),
47
- ('random negative 4', [0.1, 0, 0.4]),
48
- ]
49
-
50
- STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
51
-
52
-
53
- @pytest.fixture(scope='module', autouse=True)
54
- def setup_teardown() -> Iterable[None]:
55
- # Setup.
56
- set_default_dataset_cls(DatasetDuckDB)
57
- register_signal(TestEmbedding)
58
-
59
- # Unit test runs.
60
- yield
61
-
62
- # Teardown.
63
- clear_signal_registry()
64
-
65
-
66
- def test_search_keyword(make_test_data: TestDataMaker) -> None:
67
- dataset = make_test_data(TEST_DATA)
68
-
69
- query = 'world'
70
- result = dataset.select_rows(
71
- searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
72
- combine_columns=True)
73
-
74
- expected_signal_udf = SubstringSignal(query=query)
75
- assert list(result) == [{
76
- UUID_COLUMN: '1',
77
- 'text': enriched_item('hello world', {expected_signal_udf.key(): [lilac_span(6, 11)]}),
78
- 'text2': 'again hello world'
79
- }, {
80
- UUID_COLUMN: '2',
81
- 'text': enriched_item('looking for world in text',
82
- {expected_signal_udf.key(): [lilac_span(12, 17)]}),
83
- 'text2': 'again looking for world in text',
84
- }]
85
-
86
-
87
- def test_search_keyword_special_chars(make_test_data: TestDataMaker) -> None:
88
- dataset = make_test_data([{
89
- UUID_COLUMN: '1',
90
- 'text': 'This is 100%',
91
- }, {
92
- UUID_COLUMN: '2',
93
- 'text': 'This has _underscore_',
94
- }])
95
-
96
- query = '100%'
97
- result = dataset.select_rows(
98
- searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
99
- combine_columns=True)
100
-
101
- expected_signal_udf = SubstringSignal(query=query)
102
- assert list(result) == [{
103
- UUID_COLUMN: '1',
104
- 'text': enriched_item('This is 100%', {expected_signal_udf.key(): [lilac_span(8, 12)]}),
105
- }]
106
-
107
- query = '_underscore_'
108
- result = dataset.select_rows(
109
- searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
110
- combine_columns=True)
111
-
112
- expected_signal_udf = SubstringSignal(query=query)
113
- assert list(result) == [{
114
- UUID_COLUMN: '2',
115
- 'text': enriched_item('This has _underscore_',
116
- {expected_signal_udf.key(): [lilac_span(9, 21)]}),
117
- }]
118
-
119
-
120
- def test_search_keyword_multiple(make_test_data: TestDataMaker) -> None:
121
- dataset = make_test_data(TEST_DATA)
122
-
123
- query_world = 'world'
124
- query_looking_world = 'looking for world'
125
- expected_world_udf = SubstringSignal(query=query_world)
126
- expected_again_looking_udf = SubstringSignal(query=query_looking_world)
127
-
128
- result = dataset.select_rows(
129
- searches=[
130
- Search(path='text', query=KeywordQuery(type='keyword', search=query_world)),
131
- Search(path='text2', query=KeywordQuery(type='keyword', search=query_looking_world)),
132
- ],
133
- combine_columns=True)
134
-
135
- assert list(result) == [{
136
- UUID_COLUMN: '2',
137
- 'text': enriched_item('looking for world in text', {
138
- expected_world_udf.key(): [lilac_span(12, 17)],
139
- }),
140
- 'text2': enriched_item('again looking for world in text',
141
- {expected_again_looking_udf.key(): [lilac_span(6, 23)]})
142
- }]
143
-
144
-
145
- def test_search_keyword_with_filters(make_test_data: TestDataMaker) -> None:
146
- dataset = make_test_data(TEST_DATA)
147
-
148
- query = 'world'
149
- result = dataset.select_rows(
150
- filters=[(UUID_COLUMN, ListOp.IN, ['1', '3'])],
151
- searches=[Search(path='text', query=KeywordQuery(type='keyword', search=query))],
152
- combine_columns=True)
153
-
154
- expected_signal_udf = SubstringSignal(query=query)
155
- assert list(result) == [
156
- {
157
- UUID_COLUMN: '1',
158
- 'text': enriched_item('hello world', {expected_signal_udf.key(): [lilac_span(6, 11)]}),
159
- 'text2': 'again hello world'
160
- },
161
- # The second row doesn't match the UUID filter.
162
- ]
163
-
164
-
165
- class TestEmbedding(TextEmbeddingSignal):
166
- """A test embed function."""
167
- name = 'test_embedding'
168
-
169
- @override
170
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
171
- """Call the embedding function."""
172
- for example in data:
173
- embedding = np.array(STR_EMBEDDINGS[cast(str, example)])
174
- embedding = normalize([embedding])[0]
175
- yield [lilac_embedding(0, len(example), embedding)]
176
-
177
-
178
- def test_semantic_search(make_test_data: TestDataMaker) -> None:
179
- dataset = make_test_data([{
180
- UUID_COLUMN: '1',
181
- 'text': 'hello world.',
182
- }, {
183
- UUID_COLUMN: '2',
184
- 'text': 'hello world2.',
185
- }])
186
-
187
- test_embedding = TestEmbedding()
188
- dataset.compute_signal(test_embedding, ('text'))
189
-
190
- query = 'hello2.'
191
- result = dataset.select_rows(
192
- searches=[
193
- Search(
194
- path='text', query=SemanticQuery(type='semantic', search=query, embedding='test_embedding'))
195
- ],
196
- combine_columns=True)
197
- expected_signal_udf = SemanticSimilaritySignal(query=query, embedding='test_embedding')
198
- assert list(result) == [
199
- # Results are sorted by score desc.
200
- {
201
- UUID_COLUMN: '2',
202
- 'text': enriched_item(
203
- 'hello world2.', {
204
- test_embedding.key():
205
- [enriched_embedding_span(0, 13, {expected_signal_udf.key(): approx(0.916, 1e-3)})]
206
- })
207
- },
208
- {
209
- UUID_COLUMN: '1',
210
- 'text': enriched_item(
211
- 'hello world.', {
212
- test_embedding.key():
213
- [enriched_embedding_span(0, 12, {expected_signal_udf.key(): approx(0.885, 1e-3)})]
214
- })
215
- },
216
- ]
217
-
218
-
219
- def test_concept_search(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
220
- concept_model_mock = mocker.spy(LogisticEmbeddingModel, 'fit')
221
-
222
- dataset = make_test_data([{
223
- UUID_COLUMN: '1',
224
- 'text': 'hello world.',
225
- }, {
226
- UUID_COLUMN: '2',
227
- 'text': 'hello world2.',
228
- }, {
229
- UUID_COLUMN: '3',
230
- 'text': 'random negative 1',
231
- }, {
232
- UUID_COLUMN: '4',
233
- 'text': 'random negative 2',
234
- }, {
235
- UUID_COLUMN: '5',
236
- 'text': 'random negative 3',
237
- }, {
238
- UUID_COLUMN: '6',
239
- 'text': 'random negative 4',
240
- }])
241
-
242
- test_embedding = TestEmbedding()
243
- dataset.compute_signal(test_embedding, ('text'))
244
-
245
- concept_db = DiskConceptDB()
246
- concept_db.create(namespace='test_namespace', name='test_concept', type=SignalInputType.TEXT)
247
- concept_db.edit(
248
- 'test_namespace', 'test_concept',
249
- ConceptUpdate(insert=[
250
- ExampleIn(label=False, text='hello world.'),
251
- ExampleIn(label=True, text='hello world2.')
252
- ]))
253
-
254
- result = dataset.select_rows(
255
- searches=[
256
- Search(
257
- path='text',
258
- query=ConceptQuery(
259
- type='concept',
260
- concept_namespace='test_namespace',
261
- concept_name='test_concept',
262
- embedding='test_embedding'))
263
- ],
264
- filters=[(UUID_COLUMN, ListOp.IN, ['1', '2'])],
265
- combine_columns=True)
266
- expected_signal_udf = ConceptScoreSignal(
267
- namespace='test_namespace', concept_name='test_concept', embedding='test_embedding')
268
-
269
- assert list(result) == [
270
- # Results are sorted by score desc.
271
- {
272
- UUID_COLUMN: '2',
273
- 'text': enriched_item(
274
- 'hello world2.', {
275
- test_embedding.key():
276
- [enriched_embedding_span(0, 13, {expected_signal_udf.key(): approx(0.75, abs=0.25)})],
277
- 'test_namespace/test_concept/labels': [lilac_span(0, 13, {'label': True})]
278
- })
279
- },
280
- {
281
- UUID_COLUMN: '1',
282
- 'text': enriched_item(
283
- 'hello world.', {
284
- test_embedding.key():
285
- [enriched_embedding_span(0, 12, {expected_signal_udf.key(): approx(0.25, abs=0.25)})],
286
- 'test_namespace/test_concept/labels': [lilac_span(0, 12, {'label': False})]
287
- })
288
- },
289
- ]
290
-
291
- (_, embeddings, labels, _) = concept_model_mock.call_args_list[-1].args
292
- assert embeddings.shape == (2, 3)
293
- assert labels == [
294
- # Explicit labels.
295
- False,
296
- True
297
- ]
298
-
299
-
300
- def test_sort_override_search(make_test_data: TestDataMaker) -> None:
301
- dataset = make_test_data([{
302
- UUID_COLUMN: '1',
303
- 'text': 'hello world.',
304
- 'value': 10
305
- }, {
306
- UUID_COLUMN: '2',
307
- 'text': 'hello world2.',
308
- 'value': 20
309
- }])
310
-
311
- test_embedding = TestEmbedding()
312
- dataset.compute_signal(test_embedding, ('text'))
313
-
314
- query = 'hello2.'
315
- search = Search(
316
- path='text', query=SemanticQuery(type='semantic', search=query, embedding='test_embedding'))
317
-
318
- expected_signal_udf = SemanticSimilaritySignal(query=query, embedding='test_embedding')
319
- expected_item_1 = {
320
- UUID_COLUMN: '1',
321
- 'text': enriched_item(
322
- 'hello world.', {
323
- test_embedding.key():
324
- [enriched_embedding_span(0, 12, {expected_signal_udf.key(): approx(0.885, 1e-3)})]
325
- }),
326
- 'value': 10
327
- }
328
- expected_item_2 = {
329
- UUID_COLUMN: '2',
330
- 'text': enriched_item(
331
- 'hello world2.', {
332
- test_embedding.key():
333
- [enriched_embedding_span(0, 13, {expected_signal_udf.key(): approx(0.916, 1e-3)})]
334
- }),
335
- 'value': 20
336
- }
337
-
338
- sort_order = SortOrder.ASC
339
- result = dataset.select_rows(
340
- searches=[search], sort_by=[('value',)], sort_order=sort_order, combine_columns=True)
341
- assert list(result) == [
342
- # Results are sorted by score ascending.
343
- expected_item_1,
344
- expected_item_2
345
- ]
346
-
347
- sort_order = SortOrder.DESC
348
- result = dataset.select_rows(
349
- searches=[search], sort_by=[('text',)], sort_order=sort_order, combine_columns=True)
350
- assert list(result) == [
351
- # Results are sorted by score descending.
352
- expected_item_2,
353
- expected_item_1
354
- ]
355
-
356
-
357
- def test_search_keyword_and_semantic(make_test_data: TestDataMaker) -> None:
358
- dataset = make_test_data([{
359
- UUID_COLUMN: '1',
360
- 'text': 'hello world.',
361
- }, {
362
- UUID_COLUMN: '2',
363
- 'text': 'hello world2.',
364
- }])
365
-
366
- test_embedding = TestEmbedding()
367
- dataset.compute_signal(test_embedding, ('text'))
368
-
369
- query = 'hello2.'
370
- keyword_query = 'rld2'
371
- result = dataset.select_rows(
372
- searches=[
373
- Search(
374
- path='text', query=SemanticQuery(type='semantic', search=query,
375
- embedding='test_embedding')),
376
- Search(path='text', query=KeywordQuery(type='keyword', search=keyword_query))
377
- ],
378
- combine_columns=True)
379
- expected_semantic_signal = SemanticSimilaritySignal(query=query, embedding='test_embedding')
380
- expected_keyword_signal = SubstringSignal(query=keyword_query)
381
- assert list(result) == [
382
- # Results are sorted by score desc.
383
- {
384
- UUID_COLUMN: '2',
385
- 'text': enriched_item(
386
- 'hello world2.', {
387
- test_embedding.key():
388
- [enriched_embedding_span(0, 13, {expected_semantic_signal.key(): approx(0.916, 1e-3)})],
389
- expected_keyword_signal.key(): [lilac_span(8, 12)],
390
- })
391
- },
392
- # UUID '1' is not returned because it does not match the keyword query.
393
- ]
 
src/data/dataset_select_rows_sort_test.py DELETED
@@ -1,904 +0,0 @@
1
- """Tests for dataset.select_rows(sort_by=...)."""
2
-
3
- from typing import Iterable, Optional, Sequence, cast
4
-
5
- import numpy as np
6
- import pytest
7
- from typing_extensions import override
8
-
9
- from ..embeddings.vector_store import VectorStore
10
- from ..schema import UUID_COLUMN, Field, Item, RichData, VectorKey, field
11
- from ..signals.signal import (
12
- TextEmbeddingModelSignal,
13
- TextEmbeddingSignal,
14
- TextSignal,
15
- clear_signal_registry,
16
- register_signal,
17
- )
18
- from .dataset import BinaryOp, Column, SortOrder
19
- from .dataset_test_utils import TestDataMaker, enriched_item
20
- from .dataset_utils import lilac_embedding
21
-
22
-
23
- class TestSignal(TextSignal):
24
- name = 'test_signal'
25
-
26
- def fields(self) -> Field:
27
- return field(fields={'len': 'int32', 'is_all_cap': 'boolean'})
28
-
29
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
30
- for text_content in data:
31
- yield {'len': len(text_content), 'is_all_cap': text_content.isupper()}
32
-
33
-
34
- class TestPrimitiveSignal(TextSignal):
35
- name = 'primitive_signal'
36
-
37
- def fields(self) -> Field:
38
- return field('int32')
39
-
40
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
41
- for text_content in data:
42
- yield len(text_content) + 1
43
-
44
-
45
- class NestedArraySignal(TextSignal):
46
- name = 'nested_array'
47
-
48
- def fields(self) -> Field:
49
- return field(fields=[['int32']])
50
-
51
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
52
- for text_content in data:
53
- yield [[len(text_content) + 1], [len(text_content)]]
54
-
55
-
56
- @pytest.fixture(scope='module', autouse=True)
57
- def setup_teardown() -> Iterable[None]:
58
- # Setup.
59
- register_signal(TestSignal)
60
- register_signal(TestPrimitiveSignal)
61
- register_signal(NestedArraySignal)
62
- register_signal(TopKEmbedding)
63
- # Unit test runs.
64
- yield
65
- # Teardown.
66
- clear_signal_registry()
67
-
68
-
69
- def test_sort_by_source_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
70
- dataset = make_test_data([{
71
- UUID_COLUMN: '1',
72
- 'erased': True,
73
- 'score': 4.1,
74
- 'document': {
75
- 'num_pages': 4,
76
- 'header': {
77
- 'title': 'c'
78
- }
79
- }
80
- }, {
81
- UUID_COLUMN: '2',
82
- 'erased': False,
83
- 'score': 3.5,
84
- 'document': {
85
- 'num_pages': 5,
86
- 'header': {
87
- 'title': 'b'
88
- }
89
- },
90
- }, {
91
- UUID_COLUMN: '3',
92
- 'erased': True,
93
- 'score': 3.7,
94
- 'document': {
95
- 'num_pages': 3,
96
- 'header': {
97
- 'title': 'a'
98
- }
99
- },
100
- }])
101
-
102
- # Sort by bool.
103
- result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['erased'], sort_order=SortOrder.ASC)
104
- assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
105
- result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['erased'], sort_order=SortOrder.DESC)
106
- assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}]
107
-
108
- # Sort by float.
109
- result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['score'], sort_order=SortOrder.ASC)
110
- assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '1'}]
111
- result = dataset.select_rows(columns=[UUID_COLUMN], sort_by=['score'], sort_order=SortOrder.DESC)
112
- assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}]
113
-
114
- # Sort by nested int.
115
- result = dataset.select_rows(
116
- columns=[UUID_COLUMN], sort_by=['document.num_pages'], sort_order=SortOrder.ASC)
117
- assert list(result) == [{UUID_COLUMN: '3'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}]
118
- result = dataset.select_rows(
119
- columns=[UUID_COLUMN], sort_by=['document.num_pages'], sort_order=SortOrder.DESC)
120
- assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
121
-
122
- # Sort by double nested string.
123
- result = dataset.select_rows(
124
- columns=[UUID_COLUMN], sort_by=['document.header.title'], sort_order=SortOrder.ASC)
125
- assert list(result) == [{UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}]
126
- result = dataset.select_rows(
127
- columns=[UUID_COLUMN], sort_by=['document.header.title'], sort_order=SortOrder.DESC)
128
- assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}]
129
-
130
-
131
- def test_sort_by_signal_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
132
- dataset = make_test_data([{
133
- UUID_COLUMN: '1',
134
- 'text': 'HEY'
135
- }, {
136
- UUID_COLUMN: '2',
137
- 'text': 'everyone'
138
- }, {
139
- UUID_COLUMN: '3',
140
- 'text': 'HI'
141
- }])
142
-
143
- dataset.compute_signal(TestSignal(), 'text')
144
-
145
- # Sort by `signal.len`.
146
- result = dataset.select_rows(
147
- columns=[UUID_COLUMN], sort_by=['text.test_signal.len'], sort_order=SortOrder.ASC)
148
- assert list(result) == [{UUID_COLUMN: '3'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}]
149
- result = dataset.select_rows(
150
- columns=[UUID_COLUMN], sort_by=['text.test_signal.len'], sort_order=SortOrder.DESC)
151
- assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
152
-
153
- # Sort by `signal.is_all_cap`.
154
- result = dataset.select_rows(
155
- columns=[UUID_COLUMN], sort_by=['text.test_signal.is_all_cap'], sort_order=SortOrder.ASC)
156
- assert list(result) == [{UUID_COLUMN: '2'}, {UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}]
157
- result = dataset.select_rows(
158
- columns=[UUID_COLUMN], sort_by=['text.test_signal.is_all_cap'], sort_order=SortOrder.DESC)
159
- assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '3'}, {UUID_COLUMN: '2'}]
160
-
161
-
162
- def test_sort_by_signal_alias_no_repeated(make_test_data: TestDataMaker) -> None:
163
- dataset = make_test_data([{
164
- UUID_COLUMN: '1',
165
- 'text': 'HEY'
166
- }, {
167
- UUID_COLUMN: '2',
168
- 'text': 'everyone'
169
- }, {
170
- UUID_COLUMN: '3',
171
- 'text': 'HI'
172
- }])
173
-
174
- dataset.compute_signal(TestSignal(), 'text')
175
-
176
- # Sort by `signal.len`.
177
- signal_alias = Column('text.test_signal', alias='signal')
178
- result = dataset.select_rows(
179
- columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.ASC)
180
- assert list(result) == [{
181
- UUID_COLUMN: '3',
182
- 'signal': {
183
- 'len': 2,
184
- 'is_all_cap': True
185
- }
186
- }, {
187
- UUID_COLUMN: '1',
188
- 'signal': {
189
- 'len': 3,
190
- 'is_all_cap': True
191
- }
192
- }, {
193
- UUID_COLUMN: '2',
194
- 'signal': {
195
- 'len': 8,
196
- 'is_all_cap': False
197
- }
198
- }]
199
- result = dataset.select_rows(
200
- columns=[signal_alias], sort_by=['signal.len'], sort_order=SortOrder.DESC)
201
- assert list(result) == [{
202
- UUID_COLUMN: '2',
203
- 'signal': {
204
- 'len': 8,
205
- 'is_all_cap': False
206
- }
207
- }, {
208
- UUID_COLUMN: '1',
209
- 'signal': {
210
- 'len': 3,
211
- 'is_all_cap': True
212
- }
213
- }, {
214
- UUID_COLUMN: '3',
215
- 'signal': {
216
- 'len': 2,
217
- 'is_all_cap': True
218
- }
219
- }]
220
-
221
-
222
- def test_sort_by_enriched_alias_no_repeated(make_test_data: TestDataMaker) -> None:
223
- dataset = make_test_data([{
224
- UUID_COLUMN: '1',
225
- 'text': 'HEY'
226
- }, {
227
- UUID_COLUMN: '2',
228
- 'text': 'everyone'
229
- }, {
230
- UUID_COLUMN: '3',
231
- 'text': 'HI'
232
- }])
233
-
234
- dataset.compute_signal(TestSignal(), 'text')
235
-
236
- # Sort by `document.test_signal.is_all_cap` where 'document' is an alias to 'text'.
237
- text_alias = Column('text', alias='document')
238
- result = dataset.select_rows(
239
- columns=[text_alias], sort_by=['document.test_signal.is_all_cap'], sort_order=SortOrder.ASC)
240
- assert list(result) == [{
241
- UUID_COLUMN: '2',
242
- 'document': enriched_item('everyone', {'test_signal': {
243
- 'len': 8,
244
- 'is_all_cap': False
245
- }})
246
- }, {
247
- UUID_COLUMN: '1',
248
- 'document': enriched_item('HEY', {'test_signal': {
249
- 'len': 3,
250
- 'is_all_cap': True
251
- }})
252
- }, {
253
- UUID_COLUMN: '3',
254
- 'document': enriched_item('HI', {'test_signal': {
255
- 'len': 2,
256
- 'is_all_cap': True
257
- }})
258
- }]
259
-
260
- result = dataset.select_rows(
261
- columns=[text_alias], sort_by=['document.test_signal.is_all_cap'], sort_order=SortOrder.DESC)
262
- assert list(result) == [{
263
- UUID_COLUMN: '1',
264
- 'document': enriched_item('HEY', {'test_signal': {
265
- 'len': 3,
266
- 'is_all_cap': True
267
- }})
268
- }, {
269
- UUID_COLUMN: '3',
270
- 'document': enriched_item('HI', {'test_signal': {
271
- 'len': 2,
272
- 'is_all_cap': True
273
- }})
274
- }, {
275
- UUID_COLUMN: '2',
276
- 'document': enriched_item('everyone', {'test_signal': {
277
- 'len': 8,
278
- 'is_all_cap': False
279
- }})
280
- }]
281
-
282
-
283
- def test_sort_by_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
284
- dataset = make_test_data([{
285
- UUID_COLUMN: '1',
286
- 'text': 'HEY'
287
- }, {
288
- UUID_COLUMN: '2',
289
- 'text': 'everyone'
290
- }, {
291
- UUID_COLUMN: '3',
292
- 'text': 'HI'
293
- }])
294
-
295
- # Equivalent to: SELECT `TestSignal(text) AS udf`.
296
- text_udf = Column('text', signal_udf=TestSignal(), alias='udf')
297
- # Sort by `udf.len`, where `udf` is an alias to `TestSignal(text)`.
298
- result = dataset.select_rows(['*', text_udf], sort_by=['udf.len'], sort_order=SortOrder.ASC)
299
- assert list(result) == [{
300
- UUID_COLUMN: '3',
301
- 'text': 'HI',
302
- 'udf': {
303
- 'len': 2,
304
- 'is_all_cap': True
305
- }
306
- }, {
307
- UUID_COLUMN: '1',
308
- 'text': 'HEY',
309
- 'udf': {
310
- 'len': 3,
311
- 'is_all_cap': True
312
- }
313
- }, {
314
- UUID_COLUMN: '2',
315
- 'text': 'everyone',
316
- 'udf': {
317
- 'len': 8,
318
- 'is_all_cap': False
319
- }
320
- }]
321
-
322
-
323
- def test_sort_by_udf_no_alias_no_repeated(make_test_data: TestDataMaker) -> None:
324
- dataset = make_test_data([{
325
- UUID_COLUMN: '1',
326
- 'text': 'HEY'
327
- }, {
328
- UUID_COLUMN: '2',
329
- 'text': 'everyone'
330
- }, {
331
- UUID_COLUMN: '3',
332
- 'text': 'HI'
333
- }])
334
-
335
- text_udf = Column('text', signal_udf=TestSignal())
336
- # Sort by `text.test_signal.len`, produced by executing the udf `TestSignal(text)`.
337
- result = dataset.select_rows(['*', text_udf],
338
- sort_by=[('text', 'test_signal', 'len')],
339
- sort_order=SortOrder.ASC,
340
- combine_columns=True)
341
- assert list(result) == [{
342
- UUID_COLUMN: '3',
343
- 'text': enriched_item('HI', {'test_signal': {
344
- 'len': 2,
345
- 'is_all_cap': True
346
- }}),
347
- }, {
348
- UUID_COLUMN: '1',
349
- 'text': enriched_item('HEY', {'test_signal': {
350
- 'len': 3,
351
- 'is_all_cap': True
352
- }}),
353
- }, {
354
- UUID_COLUMN: '2',
355
- 'text': enriched_item('everyone', {'test_signal': {
356
- 'len': 8,
357
- 'is_all_cap': False
358
- }}),
359
- }]
360
-
361
- # Sort descending.
362
- result = dataset.select_rows(['*', text_udf],
363
- sort_by=[('text', 'test_signal', 'len')],
364
- sort_order=SortOrder.DESC,
365
- combine_columns=True)
366
- assert list(result) == [{
367
- UUID_COLUMN: '2',
368
- 'text': enriched_item('everyone', {'test_signal': {
369
- 'len': 8,
370
- 'is_all_cap': False
371
- }}),
372
- }, {
373
- UUID_COLUMN: '1',
374
- 'text': enriched_item('HEY', {'test_signal': {
375
- 'len': 3,
376
- 'is_all_cap': True
377
- }}),
378
- }, {
379
- UUID_COLUMN: '3',
380
- 'text': enriched_item('HI', {'test_signal': {
381
- 'len': 2,
382
- 'is_all_cap': True
383
- }}),
384
- }]
385
-
386
-
387
- def test_sort_by_primitive_udf_alias_no_repeated(make_test_data: TestDataMaker) -> None:
388
- dataset = make_test_data([{
389
- UUID_COLUMN: '1',
390
- 'text': 'HEY'
391
- }, {
392
- UUID_COLUMN: '2',
393
- 'text': 'everyone'
394
- }, {
395
- UUID_COLUMN: '3',
396
- 'text': 'HI'
397
- }])
398
-
399
- # Equivalent to: SELECT `TestPrimitiveSignal(text) AS udf`.
400
- text_udf = Column('text', signal_udf=TestPrimitiveSignal(), alias='udf')
401
- # Sort by the primitive value returned by the udf.
402
- result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.ASC)
403
- assert list(result) == [{
404
- UUID_COLUMN: '3',
405
- 'text': 'HI',
406
- 'udf': 3
407
- }, {
408
- UUID_COLUMN: '1',
409
- 'text': 'HEY',
410
- 'udf': 4
411
- }, {
412
- UUID_COLUMN: '2',
413
- 'text': 'everyone',
414
- 'udf': 9
415
- }]
416
-
417
-
418
- def test_sort_by_source_non_leaf_errors(make_test_data: TestDataMaker) -> None:
419
- dataset = make_test_data([{
420
- UUID_COLUMN: '1',
421
- 'vals': [7, 1]
422
- }, {
423
- UUID_COLUMN: '2',
424
- 'vals': [3, 4]
425
- }, {
426
- UUID_COLUMN: '3',
427
- 'vals': [9, 0]
428
- }])
429
-
430
- # Sort by repeated.
431
- with pytest.raises(ValueError, match='Unable to sort by path'):
432
- dataset.select_rows(columns=[UUID_COLUMN], sort_by=['vals'], sort_order=SortOrder.ASC)
433
-
434
-
435
- def test_sort_by_source_no_alias_repeated(make_test_data: TestDataMaker) -> None:
436
- dataset = make_test_data([{
437
- UUID_COLUMN: '1',
438
- 'vals': [[{
439
- 'score': 7
440
- }, {
441
- 'score': 1
442
- }], [{
443
- 'score': 1
444
- }, {
445
- 'score': 7
446
- }]]
447
- }, {
448
- UUID_COLUMN: '2',
449
- 'vals': [[{
450
- 'score': 3
451
- }, {
452
- 'score': 4
453
- }]]
454
- }, {
455
- UUID_COLUMN: '3',
456
- 'vals': [[{
457
- 'score': 9
458
- }, {
459
- 'score': 0
460
- }]]
461
- }])
462
-
463
- # Sort by repeated 'vals'.
464
- result = dataset.select_rows(
465
- columns=[UUID_COLUMN, 'vals'], sort_by=['vals.*.*.score'], sort_order=SortOrder.ASC)
466
- assert list(result) == [{
467
- UUID_COLUMN: '3',
468
- 'vals': [[{
469
- 'score': 9
470
- }, {
471
- 'score': 0
472
- }]]
473
- }, {
474
- UUID_COLUMN: '1',
475
- 'vals': [[{
476
- 'score': 7
477
- }, {
478
- 'score': 1
479
- }], [{
480
- 'score': 1
481
- }, {
482
- 'score': 7
483
- }]]
484
- }, {
485
- UUID_COLUMN: '2',
486
- 'vals': [[{
487
- 'score': 3
488
- }, {
489
- 'score': 4
490
- }]]
491
- }]
492
-
493
- result = dataset.select_rows(
494
- columns=[UUID_COLUMN, 'vals'], sort_by=['vals.*.*.score'], sort_order=SortOrder.DESC)
495
- assert list(result) == [{
496
- UUID_COLUMN: '3',
497
- 'vals': [[{
498
- 'score': 9
499
- }, {
500
- 'score': 0
501
- }]]
502
- }, {
503
- UUID_COLUMN: '1',
504
- 'vals': [[{
505
- 'score': 7
506
- }, {
507
- 'score': 1
508
- }], [{
509
- 'score': 1
510
- }, {
511
- 'score': 7
512
- }]]
513
- }, {
514
- UUID_COLUMN: '2',
515
- 'vals': [[{
516
- 'score': 3
517
- }, {
518
- 'score': 4
519
- }]]
520
- }]
521
-
522
-
523
- def test_sort_by_source_alias_repeated(make_test_data: TestDataMaker) -> None:
524
- dataset = make_test_data([{
525
- UUID_COLUMN: '1',
526
- 'vals': [[7, 1], [1, 7]]
527
- }, {
528
- UUID_COLUMN: '2',
529
- 'vals': [[3], [11]]
530
- }, {
531
- UUID_COLUMN: '3',
532
- 'vals': [[9, 0]]
533
- }])
534
-
535
- # Sort by repeated 'vals'.
536
- result = dataset.select_rows(
537
- columns=[UUID_COLUMN, Column('vals', alias='scores')],
538
- sort_by=['scores.*.*'],
539
- sort_order=SortOrder.ASC)
540
- assert list(result) == [{
541
- UUID_COLUMN: '3',
542
- 'scores': [[9, 0]]
543
- }, {
544
- UUID_COLUMN: '1',
545
- 'scores': [[7, 1], [1, 7]]
546
- }, {
547
- UUID_COLUMN: '2',
548
- 'scores': [[3], [11]]
549
- }]
550
-
551
- result = dataset.select_rows(
552
- columns=[UUID_COLUMN, Column('vals', alias='scores')],
553
- sort_by=['scores.*.*'],
554
- sort_order=SortOrder.DESC)
555
- assert list(result) == [{
556
- UUID_COLUMN: '2',
557
- 'scores': [[3], [11]]
558
- }, {
559
- UUID_COLUMN: '3',
560
- 'scores': [[9, 0]]
561
- }, {
562
- UUID_COLUMN: '1',
563
- 'scores': [[7, 1], [1, 7]]
564
- }]
565
-
566
-
567
- def test_sort_by_udf_alias_repeated(make_test_data: TestDataMaker) -> None:
568
- dataset = make_test_data([{
569
- UUID_COLUMN: '1',
570
- 'text': 'HEY'
571
- }, {
572
- UUID_COLUMN: '2',
573
- 'text': 'everyone'
574
- }, {
575
- UUID_COLUMN: '3',
576
- 'text': 'HI'
577
- }])
578
-
579
- # Equivalent to: SELECT `NestedArraySignal(text) AS udf`.
580
- text_udf = Column('text', signal_udf=NestedArraySignal(), alias='udf')
581
- # Sort by `udf.*.*`, where `udf` is an alias to `NestedArraySignal(text)`.
582
- result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.ASC)
583
- assert list(result) == [{
584
- UUID_COLUMN: '3',
585
- 'text': 'HI',
586
- 'udf': [[3], [2]]
587
- }, {
588
- UUID_COLUMN: '1',
589
- 'text': 'HEY',
590
- 'udf': [[4], [3]]
591
- }, {
592
- UUID_COLUMN: '2',
593
- 'text': 'everyone',
594
- 'udf': [[9], [8]]
595
- }]
596
- result = dataset.select_rows(['*', text_udf], sort_by=['udf.*.*'], sort_order=SortOrder.DESC)
597
- assert list(result) == [{
598
- UUID_COLUMN: '2',
599
- 'text': 'everyone',
600
- 'udf': [[9], [8]]
601
- }, {
602
- UUID_COLUMN: '1',
603
- 'text': 'HEY',
604
- 'udf': [[4], [3]]
605
- }, {
606
- UUID_COLUMN: '3',
607
- 'text': 'HI',
608
- 'udf': [[3], [2]]
609
- }]
610
-
611
-
612
- def test_sort_by_complex_signal_udf_alias_called_on_repeated(make_test_data: TestDataMaker) -> None:
613
- dataset = make_test_data([{
614
- UUID_COLUMN: '1',
615
- 'texts': [{
616
- 'text': 'eardrop'
617
- }, {
618
- 'text': 'I'
619
- }]
620
- }, {
621
- UUID_COLUMN: '2',
622
- 'texts': [{
623
- 'text': 'hey'
624
- }, {
625
- 'text': 'CARS'
626
- }]
627
- }, {
628
- UUID_COLUMN: '3',
629
- 'texts': [{
630
- 'text': 'everyone'
631
- }, {
632
- 'text': ''
633
- }]
634
- }])
635
-
636
- # Equivalent to: SELECT `TestSignal(texts.*.text) AS udf`.
637
- texts_udf = Column('texts.*.text', signal_udf=TestSignal(), alias='udf')
638
- # Sort by `udf.len`, where `udf` is an alias to `TestSignal(texts.*.text)`.
639
- result = dataset.select_rows(['*', texts_udf],
640
- sort_by=['udf.len'],
641
- sort_order=SortOrder.ASC,
642
- combine_columns=True)
643
- assert list(result) == [{
644
- UUID_COLUMN: '3',
645
- 'texts': [{
646
- 'text': enriched_item('everyone', {'test_signal': {
647
- 'len': 8,
648
- 'is_all_cap': False
649
- }})
650
- }, {
651
- 'text': enriched_item('', {'test_signal': {
652
- 'len': 0,
653
- 'is_all_cap': False
654
- }})
655
- }]
656
- }, {
657
- UUID_COLUMN: '1',
658
- 'texts': [{
659
- 'text': enriched_item('eardrop', {'test_signal': {
660
- 'len': 7,
661
- 'is_all_cap': False
662
- }})
663
- }, {
664
- 'text': enriched_item('I', {'test_signal': {
665
- 'len': 1,
666
- 'is_all_cap': True
667
- }})
668
- }]
669
- }, {
670
- UUID_COLUMN: '2',
671
- 'texts': [{
672
- 'text': enriched_item('hey', {'test_signal': {
673
- 'len': 3,
674
- 'is_all_cap': False
675
- }})
676
- }, {
677
- 'text': enriched_item('CARS', {'test_signal': {
678
- 'len': 4,
679
- 'is_all_cap': True
680
- }})
681
- }]
682
- }]
683
-
684
-
685
- def test_sort_by_primitive_signal_udf_alias_called_on_repeated(
686
- make_test_data: TestDataMaker) -> None:
687
- dataset = make_test_data([{
688
- UUID_COLUMN: '1',
689
- 'texts': [{
690
- 'text': 'eardrop'
691
- }, {
692
- 'text': 'I'
693
- }]
694
- }, {
695
- UUID_COLUMN: '2',
696
- 'texts': [{
697
- 'text': 'hey'
698
- }, {
699
- 'text': 'CARS'
700
- }]
701
- }, {
702
- UUID_COLUMN: '3',
703
- 'texts': [{
704
- 'text': 'everyone'
705
- }, {
706
- 'text': ''
707
- }]
708
- }])
709
-
710
- # Equivalent to: SELECT `TestPrimitiveSignal(texts.*.text) AS udf`.
711
- texts_udf = Column('texts.*.text', signal_udf=TestPrimitiveSignal(), alias='udf')
712
- # Sort by `udf`, where `udf` is an alias to `TestPrimitiveSignal(texts.*.text)`.
713
- result = dataset.select_rows(['*', texts_udf],
714
- sort_by=['udf'],
715
- sort_order=SortOrder.ASC,
716
- combine_columns=True)
717
- assert list(result) == [{
718
- UUID_COLUMN: '3',
719
- 'texts': [{
720
- 'text': enriched_item('everyone', {'primitive_signal': 9})
721
- }, {
722
- 'text': enriched_item('', {'primitive_signal': 1})
723
- }]
724
- }, {
725
- UUID_COLUMN: '1',
726
- 'texts': [{
727
- 'text': enriched_item('eardrop', {'primitive_signal': 8})
728
- }, {
729
- 'text': enriched_item('I', {'primitive_signal': 2})
730
- }]
731
- }, {
732
- UUID_COLUMN: '2',
733
- 'texts': [{
734
- 'text': enriched_item('hey', {'primitive_signal': 4})
735
- }, {
736
- 'text': enriched_item('CARS', {'primitive_signal': 5})
737
- }]
738
- }]
739
- result = dataset.select_rows(['*', texts_udf],
740
- sort_by=['udf'],
741
- sort_order=SortOrder.DESC,
742
- combine_columns=True)
743
- assert list(result) == [{
744
- UUID_COLUMN: '3',
745
- 'texts': [{
746
- 'text': enriched_item('everyone', {'primitive_signal': 9})
747
- }, {
748
- 'text': enriched_item('', {'primitive_signal': 1})
749
- }]
750
- }, {
751
- UUID_COLUMN: '1',
752
- 'texts': [{
753
- 'text': enriched_item('eardrop', {'primitive_signal': 8})
754
- }, {
755
- 'text': enriched_item('I', {'primitive_signal': 2})
756
- }]
757
- }, {
758
- UUID_COLUMN: '2',
759
- 'texts': [{
760
- 'text': enriched_item('hey', {'primitive_signal': 4})
761
- }, {
762
- 'text': enriched_item('CARS', {'primitive_signal': 5})
763
- }]
764
- }]
765
-
766
-
767
- class TopKEmbedding(TextEmbeddingSignal):
768
- """A test embed function."""
769
- name = 'topk_embedding'
770
-
771
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
772
- """Call the embedding function."""
773
- for example in data:
774
- example = cast(str, example)
775
- emb_spans: list[Item] = []
776
- for i, score in enumerate(example.split('_')):
777
- start, end = i * 2, i * 2 + 1
778
- vector = np.array([int(score)])
779
- emb_spans.append(lilac_embedding(start, end, vector))
780
- yield emb_spans
781
-
782
-
783
- class TopKSignal(TextEmbeddingModelSignal):
784
- """Compute scores along a given concept for documents."""
785
- name = 'topk_signal'
786
-
787
- _query = np.array([1])
788
-
789
- def fields(self) -> Field:
790
- return field('float32')
791
-
792
- @override
793
- def vector_compute(self, keys: Iterable[VectorKey],
794
- vector_store: VectorStore) -> Iterable[Optional[Item]]:
795
- text_embeddings = vector_store.get(keys)
796
- dot_products = text_embeddings.dot(self._query).reshape(-1)
797
- return dot_products.tolist()
798
-
799
- @override
800
- def vector_compute_topk(
801
- self,
802
- topk: int,
803
- vector_store: VectorStore,
804
- keys: Optional[Iterable[VectorKey]] = None) -> Sequence[tuple[VectorKey, Optional[Item]]]:
805
- return vector_store.topk(self._query, topk, keys)
806
-
807
-
808
- def test_sort_by_topk_embedding_udf(make_test_data: TestDataMaker) -> None:
809
- dataset = make_test_data([{
810
- UUID_COLUMN: '1',
811
- 'scores': '8_1',
812
- }, {
813
- UUID_COLUMN: '2',
814
- 'scores': '3_5'
815
- }, {
816
- UUID_COLUMN: '3',
817
- 'scores': '9_7'
818
- }])
819
-
820
- dataset.compute_signal(TopKEmbedding(), 'scores')
821
-
822
- # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
823
- text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')
824
- # Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`.
825
- result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.DESC, limit=3)
826
- assert list(result) == [{
827
- UUID_COLUMN: '3',
828
- 'scores': enriched_item(
829
- '9_7', {'topk_embedding': [lilac_embedding(0, 1, None),
830
- lilac_embedding(2, 3, None)]}),
831
- 'udf': [9.0, 7.0]
832
- }, {
833
- UUID_COLUMN: '1',
834
- 'scores': enriched_item(
835
- '8_1', {'topk_embedding': [lilac_embedding(0, 1, None),
836
- lilac_embedding(2, 3, None)]}),
837
- 'udf': [8.0, 1.0]
838
- }]
839
-
840
- # Same but set limit to 4.
841
- result = dataset.select_rows(['*', text_udf], sort_by=['udf'], sort_order=SortOrder.DESC, limit=4)
842
- assert list(result) == [{
843
- UUID_COLUMN: '3',
844
- 'scores': enriched_item(
845
- '9_7', {'topk_embedding': [lilac_embedding(0, 1, None),
846
- lilac_embedding(2, 3, None)]}),
847
- 'udf': [9.0, 7.0]
848
- }, {
849
- UUID_COLUMN: '1',
850
- 'scores': enriched_item(
851
- '8_1', {'topk_embedding': [lilac_embedding(0, 1, None),
852
- lilac_embedding(2, 3, None)]}),
853
- 'udf': [8.0, 1.0]
854
- }, {
855
- UUID_COLUMN: '2',
856
- 'scores': enriched_item(
857
- '3_5', {'topk_embedding': [lilac_embedding(0, 1, None),
858
- lilac_embedding(2, 3, None)]}),
859
- 'udf': [3.0, 5.0]
860
- }]
861
-
862
-
863
- def test_sort_by_topk_udf_with_filter(make_test_data: TestDataMaker) -> None:
864
- dataset = make_test_data([{
865
- UUID_COLUMN: '1',
866
- 'scores': '8_1',
867
- 'active': True
868
- }, {
869
- UUID_COLUMN: '2',
870
- 'scores': '3_5',
871
- 'active': True
872
- }, {
873
- UUID_COLUMN: '3',
874
- 'scores': '9_7',
875
- 'active': False
876
- }])
877
-
878
- dataset.compute_signal(TopKEmbedding(), 'scores')
879
-
880
- # Equivalent to: SELECT `TopKSignal(scores, embedding='...') AS udf`.
881
- text_udf = Column('scores', signal_udf=TopKSignal(embedding='topk_embedding'), alias='udf')
882
- # Sort by `udf`, where `udf` is an alias to `TopKSignal(scores, embedding='...')`.
883
- result = dataset.select_rows(['*', text_udf],
884
- sort_by=['udf'],
885
- filters=[('active', BinaryOp.EQUALS, True)],
886
- sort_order=SortOrder.DESC,
887
- limit=2)
888
- # We make sure that '3' is not in the result, because it is not active, even though it has the
889
- # highest topk score.
890
- assert list(result) == [{
891
- UUID_COLUMN: '1',
892
- 'active': True,
893
- 'scores': enriched_item(
894
- '8_1', {'topk_embedding': [lilac_embedding(0, 1, None),
895
- lilac_embedding(2, 3, None)]}),
896
- 'udf': [8.0, 1.0]
897
- }, {
898
- UUID_COLUMN: '2',
899
- 'active': True,
900
- 'scores': enriched_item(
901
- '3_5', {'topk_embedding': [lilac_embedding(0, 1, None),
902
- lilac_embedding(2, 3, None)]}),
903
- 'udf': [3.0, 5.0]
904
- }]
 
src/data/dataset_select_rows_udf_test.py DELETED
@@ -1,404 +0,0 @@
1
- """Tests for dataset.select_rows(udf_col)."""
2
-
3
- from typing import Iterable, Optional, cast
4
-
5
- import numpy as np
6
- import pytest
7
- from typing_extensions import override
8
-
9
- from ..embeddings.vector_store import VectorStore
10
- from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, VectorKey, field
11
- from ..signals.signal import (
12
- TextEmbeddingModelSignal,
13
- TextEmbeddingSignal,
14
- TextSignal,
15
- TextSplitterSignal,
16
- clear_signal_registry,
17
- register_signal,
18
- )
19
- from .dataset import BinaryFilterTuple, BinaryOp, Column, val
20
- from .dataset_test_utils import TestDataMaker, enriched_item
21
- from .dataset_utils import lilac_embedding, lilac_span
22
-
23
- EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
24
- ('hello2.', [1.0, 1.0, 0.0]),
25
- ('hello world.', [1.0, 1.0, 1.0]),
26
- ('hello world2.', [2.0, 1.0, 1.0])]
27
-
28
- STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
29
-
30
-
31
- class TestEmbedding(TextEmbeddingSignal):
32
- """A test embed function."""
33
- name = 'test_embedding'
34
-
35
- @override
36
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
37
- """Call the embedding function."""
38
- for example in data:
39
- yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
40
-
41
-
42
- class LengthSignal(TextSignal):
43
- name = 'length_signal'
44
-
45
- _call_count: int = 0
46
-
47
- def fields(self) -> Field:
48
- return field('int32')
49
-
50
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
51
- for text_content in data:
52
- self._call_count += 1
53
- yield len(text_content)
54
-
55
-
56
- class TestSignal(TextSignal):
57
- name = 'test_signal'
58
-
59
- @override
60
- def fields(self) -> Field:
61
- return field(fields={'len': 'int32', 'flen': 'float32'})
62
-
63
- @override
64
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
65
- return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
66
-
67
-
68
- class TestEmbeddingSumSignal(TextEmbeddingModelSignal):
69
- """Sums the embeddings to return a single floating point value."""
70
- name = 'test_embedding_sum'
71
-
72
- @override
73
- def fields(self) -> Field:
74
- return field('float32')
75
-
76
- @override
77
- def vector_compute(self, keys: Iterable[VectorKey], vector_store: VectorStore) -> Iterable[Item]:
78
- # The signal just sums the values of the embedding.
79
- embedding_sums = vector_store.get(keys).sum(axis=1)
80
- for embedding_sum in embedding_sums.tolist():
81
- yield embedding_sum
82
-
83
-
84
- class ComputedKeySignal(TextSignal):
85
- name = 'computed_key'
86
-
87
- @override
88
- def fields(self) -> Field:
89
- return field('int64')
90
-
91
- @override
92
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
93
- for text in data:
94
- yield 1
95
-
96
- def key(self, is_computed_signal: Optional[bool] = False) -> str:
97
- return f'key_{is_computed_signal}'
98
-
99
-
100
- @pytest.fixture(scope='module', autouse=True)
101
- def setup_teardown() -> Iterable[None]:
102
- # Setup.
103
- register_signal(LengthSignal)
104
- register_signal(TestSplitter)
105
- register_signal(TestEmbedding)
106
- register_signal(TestSignal)
107
- register_signal(TestEmbeddingSumSignal)
108
- register_signal(ComputedKeySignal)
109
-
110
- # Unit test runs.
111
- yield
112
- # Teardown.
113
- clear_signal_registry()
114
-
115
-
116
- def test_udf(make_test_data: TestDataMaker) -> None:
117
- dataset = make_test_data([{
118
- UUID_COLUMN: '1',
119
- 'text': 'hello'
120
- }, {
121
- UUID_COLUMN: '2',
122
- 'text': 'everybody'
123
- }])
124
-
125
- signal_col = Column('text', signal_udf=TestSignal())
126
- result = dataset.select_rows(['text', signal_col])
127
-
128
- assert list(result) == [{
129
- UUID_COLUMN: '1',
130
- 'text': 'hello',
131
- 'test_signal(text)': {
132
- 'len': 5,
133
- 'flen': 5.0
134
- }
135
- }, {
136
- UUID_COLUMN: '2',
137
- 'text': 'everybody',
138
- 'test_signal(text)': {
139
- 'len': 9,
140
- 'flen': 9.0
141
- }
142
- }]
143
-
144
-
145
- def test_udf_with_filters(make_test_data: TestDataMaker) -> None:
146
- dataset = make_test_data([{
147
- UUID_COLUMN: '1',
148
- 'text': 'hello'
149
- }, {
150
- UUID_COLUMN: '2',
151
- 'text': 'everybody'
152
- }])
153
-
154
- signal_col = Column('text', signal_udf=TestSignal())
155
- # Filter by source feature.
156
- filters: list[BinaryFilterTuple] = [('text', BinaryOp.EQUALS, 'everybody')]
157
- result = dataset.select_rows(['text', signal_col], filters=filters)
158
- assert list(result) == [{
159
- UUID_COLUMN: '2',
160
- 'text': 'everybody',
161
- 'test_signal(text)': {
162
- 'len': 9,
163
- 'flen': 9.0
164
- }
165
- }]
166
-
167
-
168
- def test_udf_with_uuid_filter(make_test_data: TestDataMaker) -> None:
169
-
170
- dataset = make_test_data([{
171
- UUID_COLUMN: '1',
172
- 'text': 'hello'
173
- }, {
174
- UUID_COLUMN: '2',
175
- 'text': 'everybody'
176
- }])
177
-
178
- # Filter by a specific UUID.
179
- filters: list[BinaryFilterTuple] = [(UUID_COLUMN, BinaryOp.EQUALS, '1')]
180
- udf_col = Column('text', signal_udf=LengthSignal())
181
- result = dataset.select_rows(['text', udf_col], filters=filters)
182
- assert list(result) == [{UUID_COLUMN: '1', 'text': 'hello', 'length_signal(text)': 5}]
183
- assert cast(LengthSignal, udf_col.signal_udf)._call_count == 1
184
-
185
- filters = [(UUID_COLUMN, BinaryOp.EQUALS, '2')]
186
- result = dataset.select_rows(['text', udf_col], filters=filters)
187
- assert list(result) == [{UUID_COLUMN: '2', 'text': 'everybody', 'length_signal(text)': 9}]
188
- assert cast(LengthSignal, udf_col.signal_udf)._call_count == 1 + 1
189
-
190
- # No filters.
191
- result = dataset.select_rows(['text', udf_col])
192
- assert list(result) == [{
193
- UUID_COLUMN: '1',
194
- 'text': 'hello',
195
- 'length_signal(text)': 5
196
- }, {
197
- UUID_COLUMN: '2',
198
- 'text': 'everybody',
199
- 'length_signal(text)': 9
200
- }]
201
- assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 + 2
202
-
203
-
204
- def test_udf_with_uuid_filter_repeated(make_test_data: TestDataMaker) -> None:
205
-
206
- dataset = make_test_data([{
207
- UUID_COLUMN: '1',
208
- 'text': ['hello', 'hi']
209
- }, {
210
- UUID_COLUMN: '2',
211
- 'text': ['everybody', 'bye', 'test']
212
- }])
213
-
214
- # Filter by a specific UUID.
215
- filters: list[BinaryFilterTuple] = [(UUID_COLUMN, BinaryOp.EQUALS, '1')]
216
- udf_col = Column(('text', '*'), signal_udf=LengthSignal())
217
- result = dataset.select_rows(['text', udf_col], filters=filters)
218
- assert list(result) == [{
219
- UUID_COLUMN: '1',
220
- 'text': ['hello', 'hi'],
221
- 'length_signal(text)': [5, 2]
222
- }]
223
- assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2
224
-
225
- # Filter by a specific UUID.
226
- filters = [(UUID_COLUMN, BinaryOp.EQUALS, '2')]
227
- result = dataset.select_rows(['text', udf_col], filters=filters)
228
- assert list(result) == [{
229
- UUID_COLUMN: '2',
230
- 'text': ['everybody', 'bye', 'test'],
231
- 'length_signal(text)': [9, 3, 4]
232
- }]
233
- assert cast(LengthSignal, udf_col.signal_udf)._call_count == 2 + 3
234
-
235
-
236
- def test_udf_deeply_nested(make_test_data: TestDataMaker) -> None:
237
- dataset = make_test_data([{
238
- UUID_COLUMN: '1',
239
- 'text': [['hello'], ['hi', 'bye']]
240
- }, {
241
- UUID_COLUMN: '2',
242
- 'text': [['everybody', 'bye'], ['test']]
243
- }])
244
-
245
- udf_col = Column(('text', '*', '*'), signal_udf=LengthSignal())
246
- result = dataset.select_rows([udf_col])
247
- assert list(result) == [{
248
- UUID_COLUMN: '1',
249
- 'length_signal(text.*)': [[5], [2, 3]]
250
- }, {
251
- UUID_COLUMN: '2',
252
- 'length_signal(text.*)': [[9, 3], [4]]
253
- }]
254
- assert cast(LengthSignal, udf_col.signal_udf)._call_count == 6
255
-
256
-
257
- def test_udf_with_embedding(make_test_data: TestDataMaker) -> None:
258
- dataset = make_test_data([{
259
- UUID_COLUMN: '1',
260
- 'text': 'hello.',
261
- }, {
262
- UUID_COLUMN: '2',
263
- 'text': 'hello2.',
264
- }])
265
-
266
- dataset.compute_signal(TestEmbedding(), 'text')
267
-
268
- signal_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
269
- result = dataset.select_rows([val('text'), signal_col])
270
-
271
- expected_result: list[Item] = [{
272
- UUID_COLUMN: '1',
273
- f'text.{VALUE_KEY}': 'hello.',
274
- 'test_embedding_sum(text.test_embedding.*.embedding)': [1.0]
275
- }, {
276
- UUID_COLUMN: '2',
277
- f'text.{VALUE_KEY}': 'hello2.',
278
- 'test_embedding_sum(text.test_embedding.*.embedding)': [2.0]
279
- }]
280
- assert list(result) == expected_result
281
-
282
- # Select rows with alias.
283
- signal_col = Column(
284
- 'text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'), alias='emb_sum')
285
- result = dataset.select_rows([val('text'), signal_col])
286
- expected_result = [{
287
- UUID_COLUMN: '1',
288
- f'text.{VALUE_KEY}': 'hello.',
289
- 'emb_sum': [1.0]
290
- }, {
291
- UUID_COLUMN: '2',
292
- f'text.{VALUE_KEY}': 'hello2.',
293
- 'emb_sum': [2.0]
294
- }]
295
- assert list(result) == expected_result
296
-
297
-
298
- def test_udf_with_nested_embedding(make_test_data: TestDataMaker) -> None:
299
- dataset = make_test_data([{
300
- UUID_COLUMN: '1',
301
- 'text': ['hello.', 'hello world.'],
302
- }, {
303
- UUID_COLUMN: '2',
304
- 'text': ['hello world2.', 'hello2.'],
305
- }])
306
-
307
- dataset.compute_signal(TestEmbedding(), ('text', '*'))
308
-
309
- signal_col = Column(('text', '*'), signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
310
- result = dataset.select_rows([val(('text', '*')), signal_col])
311
- expected_result = [{
312
- UUID_COLUMN: '1',
313
- f'text.*.{VALUE_KEY}': ['hello.', 'hello world.'],
314
- 'test_embedding_sum(text.*.test_embedding.*.embedding)': [[1.0], [3.0]]
315
- }, {
316
- UUID_COLUMN: '2',
317
- f'text.*.{VALUE_KEY}': ['hello world2.', 'hello2.'],
318
- 'test_embedding_sum(text.*.test_embedding.*.embedding)': [[4.0], [2.0]]
319
- }]
320
- assert list(result) == expected_result
321
-
322
-
323
- def test_udf_throws_without_precomputing(make_test_data: TestDataMaker) -> None:
324
- dataset = make_test_data([{
325
- UUID_COLUMN: '1',
326
- 'text': 'hello.',
327
- }, {
328
- UUID_COLUMN: '2',
329
- 'text': 'hello2.',
330
- }])
331
-
332
- # Embedding is not precomputed, yet we ask for the embedding.
333
-
334
- signal_col = Column('text', signal_udf=TestEmbeddingSumSignal(embedding='test_embedding'))
335
-
336
- with pytest.raises(ValueError, match='Embedding signal "test_embedding" is not computed'):
337
- dataset.select_rows([val('text'), signal_col])
338
-
339
-
340
- class TestSplitter(TextSplitterSignal):
341
- """Split documents into sentence by splitting on period."""
342
- name = 'test_splitter'
343
-
344
- @override
345
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
346
- for text in data:
347
- if not isinstance(text, str):
348
- raise ValueError(f'Expected text to be a string, got {type(text)} instead.')
349
- result: list[Item] = []
350
- for sentence in text.split('.'):
351
- start = text.index(sentence)
352
- end = start + len(sentence)
353
- result.append(lilac_span(start, end))
354
- yield result
355
-
356
-
357
- def test_udf_after_precomputed_split(make_test_data: TestDataMaker) -> None:
358
- dataset = make_test_data([{
359
- UUID_COLUMN: '1',
360
- 'text': 'sentence 1. sentence 2 is longer',
361
- }, {
362
- UUID_COLUMN: '2',
363
- 'text': 'sentence 1 is longer. sent2 is short',
364
- }])
365
- dataset.compute_signal(TestSplitter(), 'text')
366
- udf = Column('text', signal_udf=LengthSignal())
367
- result = dataset.select_rows(['*', udf], combine_columns=True)
368
- assert list(result) == [{
369
- UUID_COLUMN: '1',
370
- 'text': enriched_item('sentence 1. sentence 2 is longer', {
371
- 'length_signal': 32,
372
- 'test_splitter': [lilac_span(0, 10), lilac_span(11, 32)]
373
- })
374
- }, {
375
- UUID_COLUMN: '2',
376
- 'text': enriched_item('sentence 1 is longer. sent2 is short', {
377
- 'length_signal': 36,
378
- 'test_splitter': [lilac_span(0, 20), lilac_span(21, 36)]
379
- })
380
- }]
381
-
382
-
383
- def test_is_computed_signal_key(make_test_data: TestDataMaker) -> None:
384
- dataset = make_test_data([{
385
- UUID_COLUMN: '1',
386
- 'text': 'hello.',
387
- }, {
388
- UUID_COLUMN: '2',
389
- 'text': 'hello2.',
390
- }])
391
-
392
- signal_col = Column('text', signal_udf=ComputedKeySignal())
393
- # Filter by source feature.
394
- filters: list[BinaryFilterTuple] = [('text', BinaryOp.EQUALS, 'everybody')]
395
- result = dataset.select_rows(['text', signal_col])
396
- assert list(result) == [{
397
- UUID_COLUMN: '1',
398
- 'text': 'hello.',
399
- 'key_False(text)': 1
400
- }, {
401
- UUID_COLUMN: '2',
402
- 'text': 'hello2.',
403
- 'key_False(text)': 1
404
- }]
src/data/dataset_stats_test.py DELETED
@@ -1,125 +0,0 @@
1
- """Tests for dataset.stats()."""
2
-
3
- from typing import Any, cast
4
-
5
- import pytest
6
- from pytest_mock import MockerFixture
7
-
8
- from ..schema import UUID_COLUMN, Item, schema
9
- from . import dataset_duckdb
10
- from .dataset import StatsResult
11
- from .dataset_test_utils import TestDataMaker
12
-
13
- SIMPLE_ITEMS: list[Item] = [{
14
- UUID_COLUMN: '1',
15
- 'str': 'a',
16
- 'int': 1,
17
- 'bool': False,
18
- 'float': 3.0,
19
- }, {
20
- UUID_COLUMN: '2',
21
- 'str': 'b',
22
- 'int': 2,
23
- 'bool': True,
24
- 'float': 2.0
25
- }, {
26
- UUID_COLUMN: '3',
27
- 'str': 'b',
28
- 'int': 2,
29
- 'bool': True,
30
- 'float': 1.0
31
- }, {
32
- UUID_COLUMN: '4',
33
- 'float': float('nan')
34
- }]
35
-
36
-
37
- def test_simple_stats(make_test_data: TestDataMaker) -> None:
38
- dataset = make_test_data(SIMPLE_ITEMS)
39
-
40
- result = dataset.stats(leaf_path='str')
41
- assert result == StatsResult(
42
- path=('str',), total_count=3, approx_count_distinct=2, avg_text_length=1)
43
-
44
- result = dataset.stats(leaf_path='float')
45
- assert result == StatsResult(
46
- path=('float',), total_count=4, approx_count_distinct=4, min_val=1.0, max_val=3.0)
47
-
48
- result = dataset.stats(leaf_path='bool')
49
- assert result == StatsResult(path=('bool',), total_count=3, approx_count_distinct=2)
50
-
51
- result = dataset.stats(leaf_path='int')
52
- assert result == StatsResult(
53
- path=('int',), total_count=3, approx_count_distinct=2, min_val=1, max_val=2)
54
-
55
-
56
- def test_nested_stats(make_test_data: TestDataMaker) -> None:
57
- nested_items: list[Item] = [
58
- {
59
- 'name': 'Name1',
60
- 'addresses': [{
61
- 'zips': [5, 8]
62
- }]
63
- },
64
- {
65
- 'name': 'Name2',
66
- 'addresses': [{
67
- 'zips': [3]
68
- }, {
69
- 'zips': [11, 8]
70
- }]
71
- },
72
- {
73
- 'name': 'Name2',
74
- 'addresses': []
75
- }, # No addresses.
76
- {
77
- 'name': 'Name2',
78
- 'addresses': [{
79
- 'zips': []
80
- }]
81
- } # No zips in the first address.
82
- ]
83
- nested_schema = schema({
84
- UUID_COLUMN: 'string',
85
- 'name': 'string',
86
- 'addresses': [{
87
- 'zips': ['int32']
88
- }]
89
- })
90
- dataset = make_test_data(nested_items, schema=nested_schema)
91
-
92
- result = dataset.stats(leaf_path='name')
93
- assert result == StatsResult(
94
- path=('name',), total_count=4, approx_count_distinct=2, avg_text_length=5)
95
-
96
- result = dataset.stats(leaf_path='addresses.*.zips.*')
97
- assert result == StatsResult(
98
- path=('addresses', '*', 'zips', '*'),
99
- total_count=5,
100
- approx_count_distinct=4,
101
- min_val=3,
102
- max_val=11)
103
-
104
-
105
- def test_stats_approximation(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
106
- sample_size = 5
107
- mocker.patch(f'{dataset_duckdb.__name__}.SAMPLE_SIZE_DISTINCT_COUNT', sample_size)
108
-
109
- nested_items: list[Item] = [{'feature': str(i)} for i in range(sample_size * 10)]
110
- nested_schema = schema({UUID_COLUMN: 'string', 'feature': 'string'})
111
- dataset = make_test_data(nested_items, schema=nested_schema)
112
-
113
- result = dataset.stats(leaf_path='feature')
114
- assert result == StatsResult(
115
- path=('feature',), total_count=50, approx_count_distinct=50, avg_text_length=1)
116
-
117
-
118
- def test_error_handling(make_test_data: TestDataMaker) -> None:
119
- dataset = make_test_data(SIMPLE_ITEMS)
120
-
121
- with pytest.raises(ValueError, match='leaf_path must be provided'):
122
- dataset.stats(cast(Any, None))
123
-
124
- with pytest.raises(ValueError, match='Leaf "\\(\'unknown\',\\)" not found in dataset'):
125
- dataset.stats(leaf_path='unknown')
src/data/dataset_test.py DELETED
@@ -1,860 +0,0 @@
1
- """Implementation-agnostic tests of the Dataset DB API."""
2
-
3
- from typing import Iterable, Optional, cast
4
-
5
- import numpy as np
6
- import pytest
7
- from typing_extensions import override
8
-
9
- from ..schema import UUID_COLUMN, VALUE_KEY, Field, Item, RichData, field, schema
10
- from ..signals.signal import TextEmbeddingSignal, TextSignal, clear_signal_registry, register_signal
11
- from .dataset import Column, DatasetManifest, val
12
- from .dataset_test_utils import TEST_DATASET_NAME, TEST_NAMESPACE, TestDataMaker, enriched_item
13
- from .dataset_utils import lilac_embedding
14
-
15
- SIMPLE_ITEMS: list[Item] = [{
16
- UUID_COLUMN: '1',
17
- 'str': 'a',
18
- 'int': 1,
19
- 'bool': False,
20
- 'float': 3.0
21
- }, {
22
- UUID_COLUMN: '2',
23
- 'str': 'b',
24
- 'int': 2,
25
- 'bool': True,
26
- 'float': 2.0
27
- }, {
28
- UUID_COLUMN: '3',
29
- 'str': 'b',
30
- 'int': 2,
31
- 'bool': True,
32
- 'float': 1.0
33
- }]
34
-
35
- EMBEDDINGS: list[tuple[str, list[float]]] = [('hello.', [1.0, 0.0, 0.0]),
36
- ('hello2.', [1.0, 1.0, 0.0]),
37
- ('hello world.', [1.0, 1.0, 1.0]),
38
- ('hello world2.', [2.0, 1.0, 1.0])]
39
-
40
- STR_EMBEDDINGS: dict[str, list[float]] = {text: embedding for text, embedding in EMBEDDINGS}
41
-
42
-
43
- class TestEmbedding(TextEmbeddingSignal):
44
- """A test embed function."""
45
- name = 'test_embedding'
46
-
47
- @override
48
- def compute(self, data: Iterable[RichData]) -> Iterable[Item]:
49
- """Call the embedding function."""
50
- for example in data:
51
- yield [lilac_embedding(0, len(example), np.array(STR_EMBEDDINGS[cast(str, example)]))]
52
-
53
-
54
- class LengthSignal(TextSignal):
55
- name = 'length_signal'
56
-
57
- _call_count: int = 0
58
-
59
- def fields(self) -> Field:
60
- return field('int32')
61
-
62
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
63
- for text_content in data:
64
- self._call_count += 1
65
- yield len(text_content)
66
-
67
-
68
- class TestSignal(TextSignal):
69
- name = 'test_signal'
70
-
71
- @override
72
- def fields(self) -> Field:
73
- return field(fields={'len': 'int32', 'flen': 'float32'})
74
-
75
- @override
76
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
77
- return [{'len': len(text_content), 'flen': float(len(text_content))} for text_content in data]
78
-
79
-
80
- @pytest.fixture(scope='module', autouse=True)
81
- def setup_teardown() -> Iterable[None]:
82
- # Setup.
83
- register_signal(TestSignal)
84
- register_signal(LengthSignal)
85
- register_signal(SignalWithQuoteInIt)
86
- register_signal(SignalWithDoubleQuoteInIt)
87
-
88
- # Unit test runs.
89
- yield
90
-
91
- # Teardown.
92
- clear_signal_registry()
93
-
94
-
95
- def test_select_all_columns(make_test_data: TestDataMaker) -> None:
96
- dataset = make_test_data(SIMPLE_ITEMS)
97
-
98
- result = dataset.select_rows()
99
- assert list(result) == SIMPLE_ITEMS
100
-
101
-
102
- def test_select_subcols_with_dot_seperator(make_test_data: TestDataMaker) -> None:
103
- items: list[Item] = [{
104
- UUID_COLUMN: '1',
105
- 'people': [{
106
- 'name': 'A',
107
- 'address': {
108
- 'zip': 1
109
- }
110
- }, {
111
- 'name': 'B',
112
- 'address': {
113
- 'zip': 2
114
- }
115
- }]
116
- }, {
117
- UUID_COLUMN: '2',
118
- 'people': [{
119
- 'name': 'C',
120
- 'address': {
121
- 'zip': 3
122
- }
123
- }]
124
- }]
125
- dataset = make_test_data(items)
126
-
127
- result = dataset.select_rows(['people.*.name', 'people.*.address.zip'])
128
- assert list(result) == [{
129
- UUID_COLUMN: '1',
130
- 'people.*.name': ['A', 'B'],
131
- 'people.*.address.zip': [1, 2]
132
- }, {
133
- UUID_COLUMN: '2',
134
- 'people.*.name': ['C'],
135
- 'people.*.address.zip': [3]
136
- }]
137
-
138
- result = dataset.select_rows(['people.*.address.zip'], combine_columns=True)
139
- assert list(result) == [{
140
- UUID_COLUMN: '1',
141
- 'people': [{
142
- 'address': {
143
- 'zip': 1
144
- }
145
- }, {
146
- 'address': {
147
- 'zip': 2
148
- }
149
- }]
150
- }, {
151
- UUID_COLUMN: '2',
152
- 'people': [{
153
- 'address': {
154
- 'zip': 3
155
- }
156
- }]
157
- }]
158
-
159
- result = dataset.select_rows(['people'])
160
- assert list(result) == items
161
-
162
-
163
- def test_select_subcols_with_escaped_dot(make_test_data: TestDataMaker) -> None:
164
- items: list[Item] = [{
165
- UUID_COLUMN: '1',
166
- 'people.new': [{
167
- 'name': 'A'
168
- }, {
169
- 'name': 'B'
170
- }]
171
- }, {
172
- UUID_COLUMN: '2',
173
- 'people.new': [{
174
- 'name': 'C'
175
- }]
176
- }]
177
- dataset = make_test_data(items)
178
-
179
- result = dataset.select_rows(['"people.new".*.name'])
180
- assert list(result) == [{
181
- UUID_COLUMN: '1',
182
- 'people.new.*.name': ['A', 'B'],
183
- }, {
184
- UUID_COLUMN: '2',
185
- 'people.new.*.name': ['C'],
186
- }]
187
-
188
- # Escape name even though it does not need to be.
189
- result = dataset.select_rows(['"people.new".*."name"'])
190
- assert list(result) == [{
191
- UUID_COLUMN: '1',
192
- 'people.new.*.name': ['A', 'B'],
193
- }, {
194
- UUID_COLUMN: '2',
195
- 'people.new.*.name': ['C'],
196
- }]
197
-
198
-
199
- def test_select_star(make_test_data: TestDataMaker) -> None:
200
- items: list[Item] = [{
201
- UUID_COLUMN: '1',
202
- 'name': 'A',
203
- 'info': {
204
- 'age': 40
205
- }
206
- }, {
207
- UUID_COLUMN: '2',
208
- 'name': 'B',
209
- 'info': {
210
- 'age': 42
211
- }
212
- }]
213
- dataset = make_test_data(items)
214
-
215
- # Select *.
216
- result = dataset.select_rows(['*'])
217
- assert list(result) == items
218
-
219
- # Select (*,).
220
- result = dataset.select_rows([('*',)])
221
- assert list(result) == items
222
-
223
- # Select *, plus a redundant `info` column.
224
- result = dataset.select_rows(['*', 'info'])
225
- assert list(result) == [{
226
- UUID_COLUMN: '1',
227
- 'name': 'A',
228
- 'info': {
229
- 'age': 40
230
- },
231
- 'info_2': {
232
- 'age': 40
233
- },
234
- }, {
235
- UUID_COLUMN: '2',
236
- 'name': 'B',
237
- 'info': {
238
- 'age': 42
239
- },
240
- 'info_2': {
241
- 'age': 42
242
- },
243
- }]
244
-
245
- # Select * plus an inner `info.age` column.
246
- result = dataset.select_rows(['*', ('info', 'age')])
247
- assert list(result) == [{
248
- UUID_COLUMN: '1',
249
- 'name': 'A',
250
- 'info': {
251
- 'age': 40
252
- },
253
- 'info.age': 40
254
- }, {
255
- UUID_COLUMN: '2',
256
- 'name': 'B',
257
- 'info': {
258
- 'age': 42
259
- },
260
- 'info.age': 42
261
- }]
262
-
263
-
264
- def test_select_star_with_combine_cols(make_test_data: TestDataMaker) -> None:
265
- items: list[Item] = [{
266
- UUID_COLUMN: '1',
267
- 'name': 'A',
268
- 'info': {
269
- 'age': 40
270
- }
271
- }, {
272
- UUID_COLUMN: '2',
273
- 'name': 'B',
274
- 'info': {
275
- 'age': 42
276
- }
277
- }]
278
- dataset = make_test_data(items)
279
-
280
- # Select *.
281
- result = dataset.select_rows(['*'], combine_columns=True)
282
- assert list(result) == items
283
-
284
- # Select *, plus a redundant `info` column.
285
- result = dataset.select_rows(['*', 'info'], combine_columns=True)
286
- assert list(result) == items
287
-
288
- # Select * plus an inner `info.age` column.
289
- result = dataset.select_rows(['*', ('info', 'age')], combine_columns=True)
290
- assert list(result) == items
291
-
292
- # Select *, plus redundant `name`, plus a udf.
293
- udf = Column('name', signal_udf=TestSignal())
294
- result = dataset.select_rows(['*', 'name', udf], combine_columns=True)
295
-
296
- assert list(result) == [{
297
- UUID_COLUMN: '1',
298
- 'name': enriched_item('A', {'test_signal': {
299
- 'len': 1,
300
- 'flen': 1.0
301
- }}),
302
- 'info': {
303
- 'age': 40
304
- }
305
- }, {
306
- UUID_COLUMN: '2',
307
- 'name': enriched_item('B', {'test_signal': {
308
- 'len': 1,
309
- 'flen': 1.0
310
- }}),
311
- 'info': {
312
- 'age': 42
313
- }
314
- }]
315
-
316
-
317
- def test_select_ids(make_test_data: TestDataMaker) -> None:
318
- dataset = make_test_data(SIMPLE_ITEMS)
319
-
320
- result = dataset.select_rows([UUID_COLUMN])
321
-
322
- assert list(result) == [{UUID_COLUMN: '1'}, {UUID_COLUMN: '2'}, {UUID_COLUMN: '3'}]
323
-
324
-
325
- def test_select_ids_with_limit_and_offset(make_test_data: TestDataMaker) -> None:
326
- items: list[Item] = [{UUID_COLUMN: str(i)} for i in range(10, 20)]
327
- dataset = make_test_data(items)
328
-
329
- result = dataset.select_rows([UUID_COLUMN], offset=1, limit=3)
330
- assert list(result) == [{UUID_COLUMN: '11'}, {UUID_COLUMN: '12'}, {UUID_COLUMN: '13'}]
331
-
332
- result = dataset.select_rows([UUID_COLUMN], offset=7, limit=2)
333
- assert list(result) == [{UUID_COLUMN: '17'}, {UUID_COLUMN: '18'}]
334
-
335
- result = dataset.select_rows([UUID_COLUMN], offset=9, limit=200)
336
- assert list(result) == [{UUID_COLUMN: '19'}]
337
-
338
- result = dataset.select_rows([UUID_COLUMN], offset=10, limit=200)
339
- assert list(result) == []
340
-
341
-
342
- def test_columns(make_test_data: TestDataMaker) -> None:
343
- dataset = make_test_data(SIMPLE_ITEMS)
344
-
345
- result = dataset.select_rows(['str', 'float'])
346
-
347
- assert list(result) == [{
348
- UUID_COLUMN: '1',
349
- 'str': 'a',
350
- 'float': 3.0
351
- }, {
352
- UUID_COLUMN: '2',
353
- 'str': 'b',
354
- 'float': 2.0
355
- }, {
356
- UUID_COLUMN: '3',
357
- 'str': 'b',
358
- 'float': 1.0
359
- }]
360
-
361
-
362
- def test_merge_values(make_test_data: TestDataMaker) -> None:
363
- dataset = make_test_data([{
364
- UUID_COLUMN: '1',
365
- 'text': 'hello'
366
- }, {
367
- UUID_COLUMN: '2',
368
- 'text': 'everybody'
369
- }])
370
- test_signal = TestSignal()
371
- dataset.compute_signal(test_signal, 'text')
372
- length_signal = LengthSignal()
373
- dataset.compute_signal(length_signal, 'text')
374
-
375
- result = dataset.select_rows(['text'])
376
- assert list(result) == [{
377
- UUID_COLUMN: '1',
378
- 'text': enriched_item('hello', {
379
- 'length_signal': 5,
380
- 'test_signal': {
381
- 'len': 5,
382
- 'flen': 5.0
383
- }
384
- })
385
- }, {
386
- UUID_COLUMN: '2',
387
- 'text': enriched_item('everybody', {
388
- 'length_signal': 9,
389
- 'test_signal': {
390
- 'len': 9,
391
- 'flen': 9.0
392
- }
393
- }),
394
- }]
395
-
396
- # Test subselection.
397
- result = dataset.select_rows(
398
- [val('text'), ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')])
399
- assert list(result) == [{
400
- UUID_COLUMN: '1',
401
- f'text.{VALUE_KEY}': 'hello',
402
- 'text.test_signal.flen': 5.0,
403
- 'text.test_signal.len': 5
404
- }, {
405
- UUID_COLUMN: '2',
406
- f'text.{VALUE_KEY}': 'everybody',
407
- 'text.test_signal.flen': 9.0,
408
- 'text.test_signal.len': 9
409
- }]
410
-
411
- # Test subselection with combine_columns=True.
412
- result = dataset.select_rows(
413
- ['text', ('text', 'test_signal', 'flen'), ('text', 'test_signal', 'len')], combine_columns=True)
414
- assert list(result) == [{
415
- UUID_COLUMN: '1',
416
- 'text': enriched_item('hello', {
417
- 'length_signal': 5,
418
- 'test_signal': {
419
- 'len': 5,
420
- 'flen': 5.0
421
- }
422
- })
423
- }, {
424
- UUID_COLUMN: '2',
425
- 'text': enriched_item('everybody', {
426
- 'length_signal': 9,
427
- 'test_signal': {
428
- 'len': 9,
429
- 'flen': 9.0
430
- }
431
- }),
432
- }]
433
-
434
- # Test subselection with aliasing.
435
- result = dataset.select_rows(
436
- columns=[val('text'), Column(('text', 'test_signal', 'len'), alias='metadata')])
437
- assert list(result) == [{
438
- UUID_COLUMN: '1',
439
- f'text.{VALUE_KEY}': 'hello',
440
- 'metadata': 5
441
- }, {
442
- UUID_COLUMN: '2',
443
- f'text.{VALUE_KEY}': 'everybody',
444
- 'metadata': 9
445
- }]
446
-
447
- result = dataset.select_rows(columns=[Column(('text'), alias='text_enrichment')])
448
- assert list(result) == [{
449
- UUID_COLUMN: '1',
450
- 'text_enrichment': enriched_item('hello', {
451
- 'length_signal': 5,
452
- 'test_signal': {
453
- 'len': 5,
454
- 'flen': 5.0
455
- }
456
- })
457
- }, {
458
- UUID_COLUMN: '2',
459
- 'text_enrichment': enriched_item('everybody', {
460
- 'length_signal': 9,
461
- 'test_signal': {
462
- 'len': 9,
463
- 'flen': 9.0
464
- }
465
- })
466
- }]
467
-
468
-
469
- def test_merge_array_values(make_test_data: TestDataMaker) -> None:
470
- dataset = make_test_data([{
471
- UUID_COLUMN: '1',
472
- 'texts': ['hello', 'everybody']
473
- }, {
474
- UUID_COLUMN: '2',
475
- 'texts': ['a', 'bc', 'def']
476
- }])
477
-
478
- test_signal = TestSignal()
479
- dataset.compute_signal(test_signal, ('texts', '*'))
480
- length_signal = LengthSignal()
481
- dataset.compute_signal(length_signal, ('texts', '*'))
482
-
483
- assert dataset.manifest() == DatasetManifest(
484
- namespace=TEST_NAMESPACE,
485
- dataset_name=TEST_DATASET_NAME,
486
- data_schema=schema({
487
- UUID_COLUMN: 'string',
488
- 'texts': [
489
- field(
490
- 'string',
491
- fields={
492
- 'length_signal': field('int32', length_signal.dict()),
493
- 'test_signal': field(
494
- signal=test_signal.dict(), fields={
495
- 'len': 'int32',
496
- 'flen': 'float32'
497
- })
498
- })
499
- ],
500
- }),
501
- num_items=2)
502
-
503
- result = dataset.select_rows(['texts'])
504
- assert list(result) == [{
505
- UUID_COLUMN: '1',
506
- 'texts': [
507
- enriched_item('hello', {
508
- 'length_signal': 5,
509
- 'test_signal': {
510
- 'len': 5,
511
- 'flen': 5.0
512
- }
513
- }),
514
- enriched_item('everybody', {
515
- 'length_signal': 9,
516
- 'test_signal': {
517
- 'len': 9,
518
- 'flen': 9.0
519
- }
520
- })
521
- ],
522
- }, {
523
- UUID_COLUMN: '2',
524
- 'texts': [
525
- enriched_item('a', {
526
- 'length_signal': 1,
527
- 'test_signal': {
528
- 'len': 1,
529
- 'flen': 1.0
530
- }
531
- }),
532
- enriched_item('bc', {
533
- 'length_signal': 2,
534
- 'test_signal': {
535
- 'len': 2,
536
- 'flen': 2.0
537
- }
538
- }),
539
- enriched_item('def', {
540
- 'length_signal': 3,
541
- 'test_signal': {
542
- 'len': 3,
543
- 'flen': 3.0
544
- }
545
- })
546
- ],
547
- }]
548
-
549
- # Test subselection.
550
- result = dataset.select_rows(
551
- [val(('texts', '*')), ('texts', '*', 'length_signal'), ('texts', '*', 'test_signal', 'flen')])
552
- assert list(result) == [{
553
- UUID_COLUMN: '1',
554
- f'texts.*.{VALUE_KEY}': ['hello', 'everybody'],
555
- 'texts.*.test_signal.flen': [5.0, 9.0],
556
- 'texts.*.length_signal': [5, 9]
557
- }, {
558
- UUID_COLUMN: '2',
559
- f'texts.*.{VALUE_KEY}': ['a', 'bc', 'def'],
560
- 'texts.*.test_signal.flen': [1.0, 2.0, 3.0],
561
- 'texts.*.length_signal': [1, 2, 3]
562
- }]
563
-
564
-
565
- def test_combining_columns(make_test_data: TestDataMaker) -> None:
566
- dataset = make_test_data([{
567
- UUID_COLUMN: '1',
568
- 'text': 'hello',
569
- 'extra': {
570
- 'text': {
571
- 'length_signal': 5,
572
- 'test_signal': {
573
- 'len': 5,
574
- 'flen': 5.0
575
- }
576
- }
577
- }
578
- }, {
579
- UUID_COLUMN: '2',
580
- 'text': 'everybody',
581
- 'extra': {
582
- 'text': {
583
- 'length_signal': 9,
584
- 'test_signal': {
585
- 'len': 9,
586
- 'flen': 9.0
587
- }
588
- }
589
- }
590
- }])
591
-
592
- # Sub-select text and test_signal.
593
- result = dataset.select_rows(['text', ('extra', 'text', 'test_signal')], combine_columns=True)
594
- assert list(result) == [{
595
- UUID_COLUMN: '1',
596
- 'text': 'hello',
597
- 'extra': {
598
- 'text': {
599
- 'test_signal': {
600
- 'len': 5,
601
- 'flen': 5.0
602
- }
603
- }
604
- }
605
- }, {
606
- UUID_COLUMN: '2',
607
- 'text': 'everybody',
608
- 'extra': {
609
- 'text': {
610
- 'test_signal': {
611
- 'len': 9,
612
- 'flen': 9.0
613
- }
614
- }
615
- }
616
- }]
617
-
618
- # Sub-select text and length_signal.
619
- result = dataset.select_rows(['text', ('extra', 'text', 'length_signal')], combine_columns=True)
620
- assert list(result) == [{
621
- UUID_COLUMN: '1',
622
- 'text': 'hello',
623
- 'extra': {
624
- 'text': {
625
- 'length_signal': 5
626
- }
627
- }
628
- }, {
629
- UUID_COLUMN: '2',
630
- 'text': 'everybody',
631
- 'extra': {
632
- 'text': {
633
- 'length_signal': 9
634
- }
635
- }
636
- }]
637
-
638
- # Sub-select length_signal only.
639
- result = dataset.select_rows([('extra', 'text', 'length_signal')], combine_columns=True)
640
- assert list(result) == [{
641
- UUID_COLUMN: '1',
642
- 'extra': {
643
- 'text': {
644
- 'length_signal': 5
645
- }
646
- }
647
- }, {
648
- UUID_COLUMN: '2',
649
- 'extra': {
650
- 'text': {
651
- 'length_signal': 9
652
- }
653
- }
654
- }]
655
-
656
- # Aliases are ignored when combing columns.
657
- len_col = Column(('extra', 'text', 'length_signal'), alias='hello')
658
- result = dataset.select_rows([len_col], combine_columns=True)
659
- assert list(result) == [{
660
- UUID_COLUMN: '1',
661
- 'extra': {
662
- 'text': {
663
- 'length_signal': 5
664
- }
665
- }
666
- }, {
667
- UUID_COLUMN: '2',
668
- 'extra': {
669
- 'text': {
670
- 'length_signal': 9
671
- }
672
- }
673
- }]
674
-
675
- # Works with UDFs and aliases are ignored.
676
- udf_col = Column('text', alias='ignored', signal_udf=LengthSignal())
677
- result = dataset.select_rows(['text', udf_col], combine_columns=True)
678
- assert list(result) == [{
679
- UUID_COLUMN: '1',
680
- 'text': enriched_item('hello', {'length_signal': 5})
681
- }, {
682
- UUID_COLUMN: '2',
683
- 'text': enriched_item('everybody', {'length_signal': 9})
684
- }]
685
-
686
-
687
- def test_source_joined_with_named_signal(make_test_data: TestDataMaker) -> None:
688
- dataset = make_test_data(SIMPLE_ITEMS)
689
- assert dataset.manifest() == DatasetManifest(
690
- namespace=TEST_NAMESPACE,
691
- dataset_name=TEST_DATASET_NAME,
692
- data_schema=schema({
693
- UUID_COLUMN: 'string',
694
- 'str': 'string',
695
- 'int': 'int32',
696
- 'bool': 'boolean',
697
- 'float': 'float32',
698
- }),
699
- num_items=3)
700
-
701
- test_signal = TestSignal()
702
- dataset.compute_signal(test_signal, 'str')
703
-
704
- # Check the enriched dataset manifest has 'text' enriched.
705
- assert dataset.manifest() == DatasetManifest(
706
- namespace=TEST_NAMESPACE,
707
- dataset_name=TEST_DATASET_NAME,
708
- data_schema=schema({
709
- UUID_COLUMN: 'string',
710
- 'str': field(
711
- 'string',
712
- fields={
713
- 'test_signal': field(
714
- signal=test_signal.dict(), fields={
715
- 'len': 'int32',
716
- 'flen': 'float32'
717
- })
718
- }),
719
- 'int': 'int32',
720
- 'bool': 'boolean',
721
- 'float': 'float32',
722
- }),
723
- num_items=3)
724
-
725
- # Select both columns, without val() on str.
726
- result = dataset.select_rows(['str', Column(('str', 'test_signal'), alias='test_signal_on_str')])
727
-
728
- assert list(result) == [{
729
- UUID_COLUMN: '1',
730
- 'str': enriched_item('a', {'test_signal': {
731
- 'len': 1,
732
- 'flen': 1.0
733
- }}),
734
- 'test_signal_on_str': {
735
- 'len': 1,
736
- 'flen': 1.0
737
- }
738
- }, {
739
- UUID_COLUMN: '2',
740
- 'str': enriched_item('b', {'test_signal': {
741
- 'len': 1,
742
- 'flen': 1.0
743
- }}),
744
- 'test_signal_on_str': {
745
- 'len': 1,
746
- 'flen': 1.0
747
- }
748
- }, {
749
- UUID_COLUMN: '3',
750
- 'str': enriched_item('b', {'test_signal': {
751
- 'len': 1,
752
- 'flen': 1.0
753
- }}),
754
- 'test_signal_on_str': {
755
- 'len': 1,
756
- 'flen': 1.0
757
- }
758
- }]
759
-
760
- # Select both columns, with val() on str.
761
- result = dataset.select_rows(
762
- [val('str'), Column(('str', 'test_signal'), alias='test_signal_on_str')])
763
-
764
- assert list(result) == [{
765
- UUID_COLUMN: '1',
766
- f'str.{VALUE_KEY}': 'a',
767
- 'test_signal_on_str': {
768
- 'len': 1,
769
- 'flen': 1.0
770
- }
771
- }, {
772
- UUID_COLUMN: '2',
773
- f'str.{VALUE_KEY}': 'b',
774
- 'test_signal_on_str': {
775
- 'len': 1,
776
- 'flen': 1.0
777
- }
778
- }, {
779
- UUID_COLUMN: '3',
780
- f'str.{VALUE_KEY}': 'b',
781
- 'test_signal_on_str': {
782
- 'len': 1,
783
- 'flen': 1.0
784
- }
785
- }]
786
-
787
-
788
- def test_invalid_column_paths(make_test_data: TestDataMaker) -> None:
789
- dataset = make_test_data([{
790
- UUID_COLUMN: '1',
791
- 'text': enriched_item('hello', {'test_signal': {
792
- 'len': 5
793
- }}),
794
- 'text2': [
795
- enriched_item('hello', {'test_signal': {
796
- 'len': 5
797
- }}),
798
- enriched_item('hi', {'test_signal': {
799
- 'len': 2
800
- }})
801
- ],
802
- }])
803
-
804
- with pytest.raises(ValueError, match='Path part "invalid" not found in the dataset'):
805
- dataset.select_rows([('text', 'test_signal', 'invalid')])
806
-
807
- with pytest.raises(ValueError, match='Selecting a specific index of a repeated field'):
808
- dataset.select_rows([('text2', '4', 'test_signal')])
809
-
810
-
811
- def test_signal_with_quote(make_test_data: TestDataMaker) -> None:
812
- dataset = make_test_data([{
813
- UUID_COLUMN: '1',
814
- 'text': 'hello',
815
- }, {
816
- UUID_COLUMN: '2',
817
- 'text': 'world',
818
- }])
819
- dataset.compute_signal(SignalWithQuoteInIt(), 'text')
820
- dataset.compute_signal(SignalWithDoubleQuoteInIt(), 'text')
821
- result = dataset.select_rows(['text'])
822
- assert list(result) == [{
823
- UUID_COLUMN: '1',
824
- 'text': enriched_item('hello', {
825
- "test'signal": True,
826
- 'test"signal': True
827
- })
828
- }, {
829
- UUID_COLUMN: '2',
830
- 'text': enriched_item('world', {
831
- "test'signal": True,
832
- 'test"signal': True
833
- }),
834
- }]
835
-
836
-
837
- class SignalWithQuoteInIt(TextSignal):
838
- name = "test'signal"
839
-
840
- @override
841
- def fields(self) -> Field:
842
- return field('boolean')
843
-
844
- @override
845
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
846
- for d in data:
847
- yield True
848
-
849
-
850
- class SignalWithDoubleQuoteInIt(TextSignal):
851
- name = 'test"signal'
852
-
853
- @override
854
- def fields(self) -> Field:
855
- return field('boolean')
856
-
857
- @override
858
- def compute(self, data: Iterable[RichData]) -> Iterable[Optional[Item]]:
859
- for d in data:
860
- yield True
src/data/dataset_utils.py CHANGED
@@ -73,7 +73,7 @@ def lilac_embedding(start: int, end: int, embedding: Optional[np.ndarray]) -> It
73
  Tflatten = TypeVar('Tflatten', object, np.ndarray)
74
 
75
 
76
- def _flatten(input: Union[Iterable, object], is_primitive_predicate: Callable[[object],
77
  bool]) -> Generator:
78
  """Flattens a nested iterable."""
79
  if is_primitive_predicate(input):
@@ -83,13 +83,13 @@ def _flatten(input: Union[Iterable, object], is_primitive_predicate: Callable[[o
83
  elif is_primitive(input):
84
  yield input
85
  else:
86
- for elem in cast(Iterable, input):
87
  yield from _flatten(elem, is_primitive_predicate)
88
 
89
 
90
- def flatten(input: Union[Iterable, Tflatten],
91
- is_primitive_predicate: Callable[[object], bool] = is_primitive) -> Iterable[Tflatten]:
92
- """Flattens a nested iterable.
93
 
94
  Primitives and dictionaries are not flattened. The user can also provide a predicate to determine
95
  what is a primitive.
@@ -97,7 +97,7 @@ def flatten(input: Union[Iterable, Tflatten],
97
  return _flatten(input, is_primitive_predicate)
98
 
99
 
100
- def count_primitives(input: Iterable) -> int:
101
  """Iterate through each element of the input, flattening each one, computing a count.
102
 
103
  Sum the final set of counts. This is the important iterable not to exhaust.
@@ -128,7 +128,8 @@ def _unflatten(flat_input: Iterator[list[object]],
128
  return [_unflatten(flat_input, orig_elem) for orig_elem in values]
129
 
130
 
131
- def unflatten(flat_input: Iterable, original_input: Union[Iterable, object]) -> list:
 
132
  """Unflattens a flattened iterable according to the original iterable's structure."""
133
  return cast(list, _unflatten(iter(flat_input), original_input))
134
 
@@ -234,23 +235,27 @@ def write_item_embeddings_to_disk(keys: Iterable[str], embeddings: Iterable[obje
234
  return isinstance(input, np.ndarray)
235
 
236
  flat_keys = flatten_keys(keys, embeddings, is_primitive_predicate=embedding_predicate)
 
 
237
  embedding_vectors: list[np.ndarray] = []
238
- for lilac_embedding in flatten(embeddings, is_primitive_predicate=embedding_predicate):
 
 
 
 
 
239
  # We use squeeze here because embedding functions can return outer dimensions of 1.
240
- embedding_vector = lilac_embedding[EMBEDDING_KEY].reshape(-1)
241
- if embedding_vector.ndim != 1:
242
- raise ValueError(f'Expected embeddings to be 1-dimensional, got {embedding_vector.ndim} '
243
- f'with shape {embedding_vector.shape}.')
244
- embedding_vectors.append(embedding_vector)
245
 
246
- flat_embeddings = np.array(embedding_vectors)
247
 
248
  # Write the embedding index and the ordered UUID column to disk so they can be joined later.
249
 
250
  with open_file(output_path_prefix + _EMBEDDINGS_SUFFIX, 'wb') as f:
251
- np.save(f, flat_embeddings, allow_pickle=False)
252
  with open_file(output_path_prefix + _KEYS_SUFFIX, 'wb') as f:
253
- pickle.dump(flat_keys, f)
254
 
255
  return output_path_prefix
256
 
@@ -314,34 +319,63 @@ def parquet_filename(prefix: str, shard_index: int, num_shards: int) -> str:
314
 
315
 
316
  def _flatten_keys(uuid: str, nested_input: Iterable, location: list[int],
317
- is_primitive_predicate: Callable[[object], bool]) -> list[VectorKey]:
318
- if is_primitive_predicate(nested_input):
319
- return [(uuid, *location)]
320
- elif is_primitive(nested_input):
321
- return []
322
- else:
323
- result: list[VectorKey] = []
324
- if isinstance(nested_input, dict):
325
- for value in nested_input.values():
326
- result.extend(_flatten_keys(uuid, value, location, is_primitive_predicate))
327
- else:
328
- for i, input in enumerate(nested_input):
329
- result.extend(_flatten_keys(uuid, input, [*location, i], is_primitive_predicate))
330
- return result
331
 
332
 
333
  def flatten_keys(
334
  uuids: Iterable[str],
335
  nested_input: Iterable,
336
- is_primitive_predicate: Callable[[object], bool] = is_primitive) -> list[VectorKey]:
 
337
  """Flatten the uuid keys of a nested input."""
338
- result: list[VectorKey] = []
339
  for uuid, input in zip(uuids, nested_input):
340
- result.extend(_flatten_keys(uuid, input, [], is_primitive_predicate))
341
- return result
 
 
342
 
343
 
344
  def embedding_index_filename_prefix(output_dir: str, shard_index: int, num_shards: int) -> str:
345
  """Return the filename prefix for the embedding index."""
346
  npy_filename = f'embeddings-{shard_index:05d}-of-{num_shards:05d}'
347
return os.path.join(output_dir, npy_filename)
73
  Tflatten = TypeVar('Tflatten', object, np.ndarray)
74
 
75
 
76
+ def _flatten(input: Union[Iterator, object], is_primitive_predicate: Callable[[object],
77
  bool]) -> Generator:
78
  """Flattens a nested iterable."""
79
  if is_primitive_predicate(input):
 
83
  elif is_primitive(input):
84
  yield input
85
  else:
86
+ for elem in cast(Iterator, input):
87
  yield from _flatten(elem, is_primitive_predicate)
88
 
89
 
90
+ def flatten(input: Union[Iterator, Iterable, Tflatten],
91
+ is_primitive_predicate: Callable[[object], bool] = is_primitive) -> Iterator[Tflatten]:
92
+ """Flattens a nested iterator.
93
 
94
  Primitives and dictionaries are not flattened. The user can also provide a predicate to determine
95
  what is a primitive.
 
97
  return _flatten(input, is_primitive_predicate)
98
 
99
 
100
+ def count_primitives(input: Union[Iterable, Iterator]) -> int:
101
  """Iterate through each element of the input, flattening each one, computing a count.
102
 
103
  Sum the final set of counts. This is the important iterable not to exhaust.
 
128
  return [_unflatten(flat_input, orig_elem) for orig_elem in values]
129
 
130
 
131
+ def unflatten(flat_input: Union[Iterable, Iterator], original_input: Union[Iterable,
132
+ object]) -> list:
133
  """Unflattens a flattened iterable according to the original iterable's structure."""
134
  return cast(list, _unflatten(iter(flat_input), original_input))
135
 
 
235
  return isinstance(input, np.ndarray)
236
 
237
  flat_keys = flatten_keys(keys, embeddings, is_primitive_predicate=embedding_predicate)
238
+ flat_embeddings = flatten(embeddings, is_primitive_predicate=embedding_predicate)
239
+
240
  embedding_vectors: list[np.ndarray] = []
241
+ embedding_keys: list[VectorKey] = []
242
+ for key, lilac_embedding in zip(flat_keys, flat_embeddings):
243
+ if not key or not lilac_embedding or EMBEDDING_KEY not in lilac_embedding:
244
+ # Sparse embeddings may not have an embedding for every key.
245
+ continue
246
+
247
  # We use squeeze here because embedding functions can return outer dimensions of 1.
248
+ embedding_vectors.append(lilac_embedding[EMBEDDING_KEY].reshape(-1))
249
+ embedding_keys.append(key)
 
 
 
250
 
251
+ embedding_vectors = np.array(embedding_vectors)
252
 
253
  # Write the embedding index and the ordered UUID column to disk so they can be joined later.
254
 
255
  with open_file(output_path_prefix + _EMBEDDINGS_SUFFIX, 'wb') as f:
256
+ np.save(f, embedding_vectors, allow_pickle=False)
257
  with open_file(output_path_prefix + _KEYS_SUFFIX, 'wb') as f:
258
+ pickle.dump(embedding_keys, f)
259
 
260
  return output_path_prefix
261
 
 
319
 
320
 
321
  def _flatten_keys(uuid: str, nested_input: Iterable, location: list[int],
322
+ is_primitive_predicate: Callable[[object], bool]) -> Iterator[VectorKey]:
323
+ if is_primitive_predicate(nested_input) or is_primitive(nested_input) or isinstance(
324
+ nested_input, dict):
325
+ yield (uuid, *location)
326
+ return
327
+
328
+ for i, input in enumerate(nested_input):
329
+ yield from _flatten_keys(uuid, input, [*location, i], is_primitive_predicate)
 
 
 
 
 
 
330
 
331
 
332
  def flatten_keys(
333
  uuids: Iterable[str],
334
  nested_input: Iterable,
335
+ is_primitive_predicate: Callable[[object],
336
+ bool] = is_primitive) -> Iterator[Optional[VectorKey]]:
337
  """Flatten the uuid keys of a nested input."""
 
338
  for uuid, input in zip(uuids, nested_input):
339
+ if input is None:
340
+ yield None
341
+ continue
342
+ yield from _flatten_keys(uuid, input, [], is_primitive_predicate)
343
 
344
 
345
  def embedding_index_filename_prefix(output_dir: str, shard_index: int, num_shards: int) -> str:
346
  """Return the filename prefix for the embedding index."""
347
  npy_filename = f'embeddings-{shard_index:05d}-of-{num_shards:05d}'
348
  return os.path.join(output_dir, npy_filename)
349
+
350
+
351
+ Tin = TypeVar('Tin')
352
+ Tout = TypeVar('Tout')
353
+
354
+
355
+ def sparse_to_dense_compute(
356
+ sparse_input: Iterator[Optional[Tin]],
357
+ func: Callable[[Iterable[Tin]], Iterable[Tout]]) -> Iterator[Optional[Tout]]:
358
+ """Densifies the input before calling the provided `func` and sparsifies the output."""
359
+ empty_mask: list[bool] = []
360
+
361
+ def densify(x: Iterator[Optional[Tin]]) -> Iterator[Tin]:
362
+ nonlocal empty_mask
363
+ for i, value in enumerate(x):
364
+ empty_mask.append(value is None)
365
+ if value is not None:
366
+ yield value
367
+
368
+ dense_input = densify(sparse_input)
369
+ dense_output = iter(func(dense_input))
370
+ index = 0
371
+
372
+ while True:
373
+ try:
374
+ out = next(dense_output)
375
+ yield (None if empty_mask[index] else out)
376
+ index += 1
377
+ except StopIteration:
378
+ while index < len(empty_mask):
379
+ yield None
380
+ index += 1
381
+ return
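A note on the src/data/dataset_utils.py change above: flatten_keys now yields None for rows that carry no value, write_item_embeddings_to_disk skips key/embedding pairs that have no embedding entry, and the new sparse_to_dense_compute helper densifies a sparse iterator before handing it to a compute function, then re-inserts None placeholders into the output. Below is a minimal, hypothetical usage sketch: only sparse_to_dense_compute comes from the diff, while the embed function, the sample inputs, and the import path are invented for illustration.

from typing import Iterable, Iterator, Optional

# Assumes the repository root is on PYTHONPATH; the module path mirrors the file above.
from src.data.dataset_utils import sparse_to_dense_compute


def embed(texts: Iterable[str]) -> Iterator[list[float]]:
  # Stand-in compute function: it only ever sees the non-None values.
  for text in texts:
    yield [float(len(text))]


sparse_texts: list[Optional[str]] = ['hello', None, 'hi']

# Intended behavior per the docstring: one output per input position,
# with None wherever the input was None.
results = list(sparse_to_dense_compute(iter(sparse_texts), embed))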
src/data/dataset_utils_test.py DELETED
@@ -1,114 +0,0 @@
1
- """Tests for dataset utils."""
2
- from ..schema import PathTuple
3
- from .dataset_utils import count_primitives, flatten, unflatten, wrap_in_dicts
4
-
5
-
6
- def test_flatten() -> None:
7
- a = [[1, 2], [[3]], [4, 5, 5]]
8
- result = list(flatten(a))
9
- assert result == [1, 2, 3, 4, 5, 5]
10
-
11
-
12
- def test_flatten_primitive() -> None:
13
- result = list(flatten('hello'))
14
- assert result == ['hello']
15
-
16
-
17
- def test_unflatten() -> None:
18
- a = [[1, 2], [[3]], [4, 5, 5]]
19
- flat_a = list(flatten(a))
20
- result = unflatten(flat_a, a)
21
- assert result == [[1, 2], [[3]], [4, 5, 5]]
22
-
23
-
24
- def test_count_nested() -> None:
25
- a = [[1, 2], [[3]], [4, 5, 6]]
26
- assert 6 == count_primitives(a)
27
-
28
-
29
- def test_wrap_in_dicts_with_spec_of_one_repeated() -> None:
30
- a = [[1, 2], [3], [4, 5, 5]]
31
- spec: list[PathTuple] = [('a', 'b', 'c'), ('d',)] # Corresponds to a.b.c.*.d.
32
- result = wrap_in_dicts(a, spec)
33
- assert result == [{
34
- 'a': {
35
- 'b': {
36
- 'c': [{
37
- 'd': 1
38
- }, {
39
- 'd': 2
40
- }]
41
- }
42
- }
43
- }, {
44
- 'a': {
45
- 'b': {
46
- 'c': [{
47
- 'd': 3
48
- }]
49
- }
50
- }
51
- }, {
52
- 'a': {
53
- 'b': {
54
- 'c': [{
55
- 'd': 4
56
- }, {
57
- 'd': 5
58
- }, {
59
- 'd': 5
60
- }]
61
- }
62
- }
63
- }]
64
-
65
-
66
- def test_wrap_in_dicts_with_spec_of_double_repeated() -> None:
67
- a = [[[1, 2], [3, 4, 5]], [[6]], [[7], [8], [9, 10]]]
68
- spec: list[PathTuple] = [('a', 'b'), tuple(), ('c',)] # Corresponds to a.b.*.*.c.
69
- result = wrap_in_dicts(a, spec)
70
- assert result == [{
71
- 'a': {
72
- 'b': [[{
73
- 'c': 1
74
- }, {
75
- 'c': 2
76
- }], [{
77
- 'c': 3
78
- }, {
79
- 'c': 4
80
- }, {
81
- 'c': 5
82
- }]]
83
- }
84
- }, {
85
- 'a': {
86
- 'b': [[{
87
- 'c': 6
88
- }]]
89
- }
90
- }, {
91
- 'a': {
92
- 'b': [[{
93
- 'c': 7
94
- }], [{
95
- 'c': 8
96
- }], [{
97
- 'c': 9
98
- }, {
99
- 'c': 10
100
- }]]
101
- }
102
- }]
103
-
104
-
105
- def test_unflatten_primitive() -> None:
106
- original = 'hello'
107
- result = unflatten(['hello'], original)
108
- assert result == 'hello'
109
-
110
-
111
- def test_unflatten_primitive_list() -> None:
112
- original = ['hello', 'world']
113
- result = unflatten(['hello', 'world'], original)
114
- assert result == ['hello', 'world']
src/data/sources/csv_source_test.py DELETED
@@ -1,42 +0,0 @@
1
- """Tests for the CSV source."""
2
- import csv
3
- import os
4
- import pathlib
5
-
6
- from ...schema import schema
7
- from .csv_source import LINE_NUMBER_COLUMN, CSVDataset
8
- from .source import SourceSchema
9
-
10
-
11
- def test_csv(tmp_path: pathlib.Path) -> None:
12
- csv_rows = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
13
-
14
- filename = 'test-dataset.csv'
15
- filepath = os.path.join(tmp_path, filename)
16
- with open(filepath, 'w') as f:
17
- writer = csv.DictWriter(f, fieldnames=list(csv_rows[0].keys()))
18
- writer.writeheader()
19
- writer.writerows(csv_rows)
20
-
21
- source = CSVDataset(filepaths=[filepath])
22
- source.setup()
23
-
24
- source_schema = source.source_schema()
25
- assert source_schema == SourceSchema(
26
- fields=schema({
27
- LINE_NUMBER_COLUMN: 'int64',
28
- 'x': 'int64',
29
- 'y': 'string'
30
- }).fields, num_items=2)
31
-
32
- items = list(source.process())
33
-
34
- assert items == [{
35
- LINE_NUMBER_COLUMN: 0,
36
- 'x': 1,
37
- 'y': 'ten'
38
- }, {
39
- LINE_NUMBER_COLUMN: 1,
40
- 'x': 2,
41
- 'y': 'twenty'
42
- }]
src/data/sources/huggingface_source_test.py DELETED
@@ -1,170 +0,0 @@
-"""Tests for the HuggingFace source."""
-import os
-import pathlib
-
-# mypy: disable-error-code="attr-defined"
-from datasets import Dataset, Features, Sequence, Value
-
-from ...schema import schema
-from .huggingface_source import HF_SPLIT_COLUMN, HuggingFaceDataset
-from .source import SourceSchema
-
-
-def test_hf(tmp_path: pathlib.Path) -> None:
-  dataset = Dataset.from_list([{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}])
-
-  dataset_name = os.path.join(tmp_path, 'hf-test-dataset')
-  dataset.save_to_disk(dataset_name)
-
-  source = HuggingFaceDataset(dataset_name=dataset_name, load_from_disk=True)
-
-  items = source.process()
-  source.setup()
-
-  source_schema = source.source_schema()
-  assert source_schema == SourceSchema(
-    fields=schema({
-      HF_SPLIT_COLUMN: 'string',
-      'x': 'int64',
-      'y': 'string'
-    }).fields, num_items=2)
-
-  items = list(source.process())
-
-  assert items == [{
-    HF_SPLIT_COLUMN: 'default',
-    'x': 1,
-    'y': 'ten'
-  }, {
-    HF_SPLIT_COLUMN: 'default',
-    'x': 2,
-    'y': 'twenty'
-  }]
-
-
-def test_hf_sequence(tmp_path: pathlib.Path) -> None:
-  dataset = Dataset.from_list([{
-    'scalar': 1,
-    'seq': [1, 0],
-    'seq_dict': {
-      'x': [1, 2, 3],
-      'y': ['four', 'five', 'six']
-    }
-  }, {
-    'scalar': 2,
-    'seq': [2, 0],
-    'seq_dict': {
-      'x': [10, 20, 30],
-      'y': ['forty', 'fifty', 'sixty']
-    }
-  }],
-    features=Features({
-      'scalar': Value(dtype='int64'),
-      'seq': Sequence(feature=Value(dtype='int64')),
-      'seq_dict': Sequence(feature={
-        'x': Value(dtype='int64'),
-        'y': Value(dtype='string')
-      })
-    }))
-
-  dataset_name = os.path.join(tmp_path, 'hf-test-dataset')
-  dataset.save_to_disk(dataset_name)
-
-  source = HuggingFaceDataset(dataset_name=dataset_name, load_from_disk=True)
-
-  items = source.process()
-  source.setup()
-
-  source_schema = source.source_schema()
-  assert source_schema == SourceSchema(
-    fields=schema({
-      HF_SPLIT_COLUMN: 'string',
-      'scalar': 'int64',
-      'seq': ['int64'],
-      'seq_dict': {
-        'x': ['int64'],
-        'y': ['string'],
-      },
-    }).fields,
-    num_items=2)
-
-  items = list(source.process())
-
-  assert items == [{
-    HF_SPLIT_COLUMN: 'default',
-    'scalar': 1,
-    'seq': [1, 0],
-    'seq_dict': {
-      'x': [1, 2, 3],
-      'y': ['four', 'five', 'six']
-    }
-  }, {
-    HF_SPLIT_COLUMN: 'default',
-    'scalar': 2,
-    'seq': [2, 0],
-    'seq_dict': {
-      'x': [10, 20, 30],
-      'y': ['forty', 'fifty', 'sixty']
-    }
-  }]
-
-
-def test_hf_list(tmp_path: pathlib.Path) -> None:
-  dataset = Dataset.from_list([{
-    'scalar': 1,
-    'list': [{
-      'x': 1,
-      'y': 'two'
-    }]
-  }, {
-    'scalar': 2,
-    'list': [{
-      'x': 3,
-      'y': 'four'
-    }]
-  }],
-    features=Features({
-      'scalar': Value(dtype='int64'),
-      'list': [{
-        'x': Value(dtype='int64'),
-        'y': Value(dtype='string')
-      }]
-    }))
-
-  dataset_name = os.path.join(tmp_path, 'hf-test-dataset')
-  dataset.save_to_disk(dataset_name)
-
-  source = HuggingFaceDataset(dataset_name=dataset_name, load_from_disk=True)
-
-  items = source.process()
-  source.setup()
-
-  source_schema = source.source_schema()
-  assert source_schema == SourceSchema(
-    fields=schema({
-      HF_SPLIT_COLUMN: 'string',
-      'scalar': 'int64',
-      'list': [{
-        'x': 'int64',
-        'y': 'string',
-      }],
-    }).fields,
-    num_items=2)
-
-  items = list(source.process())
-
-  assert items == [{
-    HF_SPLIT_COLUMN: 'default',
-    'scalar': 1,
-    'list': [{
-      'x': 1,
-      'y': 'two'
-    }]
-  }, {
-    HF_SPLIT_COLUMN: 'default',
-    'scalar': 2,
-    'list': [{
-      'x': 3,
-      'y': 'four'
-    }]
-  }]

src/data/sources/json_source_test.py DELETED
@@ -1,74 +0,0 @@
-"""Tests for the JSON source."""
-import json
-import os
-import pathlib
-
-from ...schema import schema
-from .json_source import ROW_ID_COLUMN, JSONDataset
-from .source import SourceSchema
-
-
-def test_simple_json(tmp_path: pathlib.Path) -> None:
-  json_records = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
-
-  filename = 'test-dataset.jsonl'
-  filepath = os.path.join(tmp_path, filename)
-  with open(filepath, 'w') as f:
-    f.write(json.dumps(json_records))
-
-  source = JSONDataset(filepaths=[filepath])
-  source.setup()
-
-  source_schema = source.source_schema()
-  assert source_schema == SourceSchema(
-    fields=schema({
-      ROW_ID_COLUMN: 'int64',
-      'x': 'int64',
-      'y': 'string'
-    }).fields, num_items=2)
-
-  items = list(source.process())
-
-  assert items == [{
-    ROW_ID_COLUMN: 0,
-    'x': 1,
-    'y': 'ten'
-  }, {
-    ROW_ID_COLUMN: 1,
-    'x': 2,
-    'y': 'twenty'
-  }]
-
-
-def test_simple_jsonl(tmp_path: pathlib.Path) -> None:
-  json_records = [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
-  json_lines = [json.dumps(record) + '\n' for record in json_records]
-
-  filename = 'test-dataset.jsonl'
-  filepath = os.path.join(tmp_path, filename)
-  with open(filepath, 'w') as f:
-    f.writelines(json_lines)
-
-  source = JSONDataset(dataset_name='test_dataset', filepaths=[filepath])
-  source.setup()
-
-  source_schema = source.source_schema()
-
-  assert source_schema == SourceSchema(
-    fields=schema({
-      ROW_ID_COLUMN: 'int64',
-      'x': 'int64',
-      'y': 'string'
-    }).fields, num_items=2)
-
-  items = list(source.process())
-
-  assert items == [{
-    ROW_ID_COLUMN: 0,
-    'x': 1,
-    'y': 'ten'
-  }, {
-    ROW_ID_COLUMN: 1,
-    'x': 2,
-    'y': 'twenty'
-  }]

src/data/sources/pandas_source_test.py DELETED
@@ -1,91 +0,0 @@
-"""Tests for the pandas source."""
-
-import pandas as pd
-
-from ...schema import schema
-from .pandas_source import PANDAS_INDEX_COLUMN, PandasDataset
-from .source import SourceSchema
-
-
-def test_simple_dataframe() -> None:
-  df = pd.DataFrame.from_records([{
-    'name': 'a',
-    'age': 1
-  }, {
-    'name': 'b',
-    'age': 2
-  }, {
-    'name': 'c',
-    'age': 3
-  }])
-
-  source = PandasDataset(df)
-  source.setup()
-
-  source_schema = source.source_schema()
-  assert source_schema == SourceSchema(
-    fields=schema({
-      PANDAS_INDEX_COLUMN: 'int64',
-      'name': 'string',
-      'age': 'int64'
-    }).fields,
-    num_items=3)
-
-  items = list(source.process())
-
-  assert items == [{
-    PANDAS_INDEX_COLUMN: 0,
-    'name': 'a',
-    'age': 1
-  }, {
-    PANDAS_INDEX_COLUMN: 1,
-    'name': 'b',
-    'age': 2
-  }, {
-    PANDAS_INDEX_COLUMN: 2,
-    'name': 'c',
-    'age': 3
-  }]
-
-
-def test_simple_dataframe_with_index() -> None:
-  df = pd.DataFrame.from_records([{
-    'name': 'a',
-    'age': 1
-  }, {
-    'name': 'b',
-    'age': 2
-  }, {
-    'name': 'c',
-    'age': 3
-  }],
-    index=['id1', 'id2', 'id3'])
-
-  source = PandasDataset(df)
-  source.setup()
-
-  source_schema = source.source_schema()
-  assert source_schema == SourceSchema(
-    fields=schema({
-      PANDAS_INDEX_COLUMN: 'string',
-      'name': 'string',
-      'age': 'int64'
-    }).fields,
-    num_items=3)
-
-  items = list(source.process())
-
-  # The PANDAS_INDEX_COLUMN aligns with the pandas index.
-  assert items == [{
-    PANDAS_INDEX_COLUMN: 'id1',
-    'name': 'a',
-    'age': 1
-  }, {
-    PANDAS_INDEX_COLUMN: 'id2',
-    'name': 'b',
-    'age': 2
-  }, {
-    PANDAS_INDEX_COLUMN: 'id3',
-    'name': 'c',
-    'age': 3
-  }]

src/data/sources/source_registry_test.py DELETED
@@ -1,55 +0,0 @@
-"""Tests for the source registry."""
-from typing import Iterable, cast
-
-import pytest
-from typing_extensions import override
-
-from ...schema import Item
-from .source import Source, SourceSchema
-from .source_registry import clear_source_registry, get_source_cls, register_source, resolve_source
-
-
-class TestSource(Source):
-  """A test source."""
-  name = 'test_source'
-
-  @override
-  def setup(self) -> None:
-    pass
-
-  @override
-  def source_schema(self) -> SourceSchema:
-    """Return the source schema."""
-    return cast(SourceSchema, None)
-
-  @override
-  def process(self) -> Iterable[Item]:
-    yield None
-
-
-@pytest.fixture(scope='module', autouse=True)
-def setup_teardown() -> Iterable[None]:
-  # Setup.
-  register_source(TestSource)
-
-  # Unit test runs.
-  yield
-
-  # Teardown.
-  clear_source_registry()
-
-
-def test_get_source_cls() -> None:
-  """Test getting a source."""
-  assert TestSource == get_source_cls('test_source')
-
-
-def test_resolve_source() -> None:
-  """Test resolving a source."""
-  test_source = TestSource()
-
-  # Sources pass through.
-  assert resolve_source(test_source) == test_source
-
-  # Dicts resolve to the base class.
-  assert resolve_source(test_source.dict()) == test_source

src/data_loader_test.py DELETED
@@ -1,74 +0,0 @@
-"""Tests for data_loader.py."""
-
-import os
-import pathlib
-import uuid
-from typing import Iterable
-
-from pytest_mock import MockerFixture
-from typing_extensions import override
-
-from .data.dataset_duckdb import read_source_manifest
-from .data.dataset_utils import parquet_filename
-from .data.sources.source import Source, SourceSchema
-from .data_loader import process_source
-from .schema import PARQUET_FILENAME_PREFIX, UUID_COLUMN, Item, SourceManifest, schema
-from .test_utils import fake_uuid, read_items
-from .utils import DATASETS_DIR_NAME
-
-
-class TestSource(Source):
-  """A test source."""
-  name = 'test_source'
-
-  @override
-  def setup(self) -> None:
-    pass
-
-  @override
-  def source_schema(self) -> SourceSchema:
-    """Return the source schema."""
-    return SourceSchema(fields=schema({'x': 'int64', 'y': 'string'}).fields, num_items=2)
-
-  @override
-  def process(self) -> Iterable[Item]:
-    return [{'x': 1, 'y': 'ten'}, {'x': 2, 'y': 'twenty'}]
-
-
-def test_data_loader(tmp_path: pathlib.Path, mocker: MockerFixture) -> None:
-  mock_uuid = mocker.patch.object(uuid, 'uuid4', autospec=True)
-  mock_uuid.side_effect = [fake_uuid(b'1'), fake_uuid(b'2')]
-
-  source = TestSource()
-  setup_mock = mocker.spy(TestSource, 'setup')
-
-  output_dir, num_items = process_source(tmp_path, 'test_namespace', 'test_dataset', source)
-
-  assert setup_mock.call_count == 1
-
-  assert output_dir == os.path.join(tmp_path, DATASETS_DIR_NAME, 'test_namespace', 'test_dataset')
-  assert num_items == 2
-
-  source_manifest = read_source_manifest(output_dir)
-
-  assert source_manifest == SourceManifest(
-    files=[parquet_filename(PARQUET_FILENAME_PREFIX, 0, 1)],
-    data_schema=schema({
-      # UUID_COLUMN is generated by the data loader.
-      UUID_COLUMN: 'string',
-      'x': 'int64',
-      'y': 'string'
-    }),
-  )
-
-  items = read_items(output_dir, source_manifest.files, source_manifest.data_schema)
-
-  assert items == [{
-    UUID_COLUMN: fake_uuid(b'1').hex,
-    'x': 1,
-    'y': 'ten'
-  }, {
-    UUID_COLUMN: fake_uuid(b'2').hex,
-    'x': 2,
-    'y': 'twenty'
-  }]

src/embeddings/embedding.py CHANGED
@@ -57,7 +57,7 @@ def compute_split_embeddings(docs: Iterable[str],
   pool = ThreadPoolExecutor()
 
   def _splitter(doc: str) -> list[TextChunk]:
-    if doc is None:
+    if not doc:
       return []
     if split_fn:
       return split_fn(doc)
@@ -65,15 +65,19 @@ def compute_split_embeddings(docs: Iterable[str],
     # Return a single chunk that spans the entire document.
     return [(doc, (0, len(doc)))]
 
+  num_docs = 0
+
   def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
     """Split a batch of documents into chunks and yield them."""
+    nonlocal num_docs
     for i, doc in enumerate(docs):
-      chunks = _splitter(doc) or [cast(TextChunk, ('', (0, 0)))]
+      num_docs += 1
+      chunks = _splitter(doc)
       for chunk in chunks:
         yield (i, chunk)
 
   doc_chunks = _flat_split_batch_docs(docs)
-  items_to_yield: list[Item] = []
+  items_to_yield: Optional[list[Item]] = None
   current_index = 0
 
   mega_batch_size = batch_size * num_parallel_requests
@@ -81,19 +85,27 @@ def compute_split_embeddings(docs: Iterable[str],
   for batch in chunks(doc_chunks, mega_batch_size):
     texts = [text for _, (text, _) in batch]
     embeddings: list[np.ndarray] = []
+
     for x in list(pool.map(lambda x: embed_fn(x), chunks(texts, batch_size))):
       embeddings.extend(x)
     matrix = normalize(np.array(embeddings)).astype(np.float16)
     # np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
     embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
     for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
+      embedding = embedding.reshape(-1)
       if index == current_index:
+        if items_to_yield is None:
+          items_to_yield = []
         items_to_yield.append(lilac_embedding(start, end, embedding))
       else:
         yield items_to_yield
+        current_index += 1
+        while current_index < index:
+          yield None
+          current_index += 1
         items_to_yield = [lilac_embedding(start, end, embedding)]
-        current_index = index
 
-  # Yield the last batch.
-  if items_to_yield:
+  while current_index < num_docs:
     yield items_to_yield
+    items_to_yield = None
+    current_index += 1
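
The hunks above change compute_split_embeddings so that every input document maps to exactly one yielded value: a document that splits into no chunks (empty or missing text) now yields a None placeholder instead of being silently dropped, and num_docs tracks how many outputs are still owed when the input is exhausted. Below is a minimal standalone sketch of that alignment pattern, not lilac's actual implementation; embed_aligned, embed_fn and fake_embed are hypothetical names used only for illustration.

from typing import Callable, Iterable, Iterator, Optional


def embed_aligned(docs: Iterable[str],
                  embed_fn: Callable[[str], list[float]]) -> Iterator[Optional[list[list[float]]]]:
  """Yield one list of chunk embeddings per input doc, or None when a doc has no chunks."""
  for doc in docs:
    if not doc:
      # Placeholder keeps the output aligned one-to-one with the input documents.
      yield None
      continue
    # A real splitter would produce several chunks; a single chunk is enough for the sketch.
    chunks = [doc]
    yield [embed_fn(chunk) for chunk in chunks]


if __name__ == '__main__':
  def fake_embed(text: str) -> list[float]:
    return [float(len(text))]

  print(list(embed_aligned(['hello', '', 'world'], fake_embed)))
  # -> [[[5.0]], None, [[5.0]]]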