diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 168d54c59..f86800a25 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -18,7 +18,7 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - working-directory: pgml-sdks/rust/pgml/javascript + working-directory: pgml-sdks/pgml/javascript steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -34,14 +34,11 @@ jobs: run: | npm i npm run build-release - mv index.node ${{ matrix.neon-out-name }} - - name: Display output files - run: ls -R - name: Upload built .node file uses: actions/upload-artifact@v3 with: name: node-artifacts - path: pgml-sdks/rust/pgml/javascript/${{ matrix.neon-out-name }} + path: pgml-sdks/pgml/javascript/dist/${{ matrix.neon-out-name }} retention-days: 1 # publish-javascript-sdk: # needs: build-javascript-sdk diff --git a/.github/workflows/python-sdk.yml b/.github/workflows/python-sdk.yml index fc562778b..e8d042fff 100644 --- a/.github/workflows/python-sdk.yml +++ b/.github/workflows/python-sdk.yml @@ -14,7 +14,7 @@ jobs: runs-on: ${{ matrix.os }} defaults: run: - working-directory: pgml-sdks/rust/pgml + working-directory: pgml-sdks/pgml steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -62,7 +62,7 @@ jobs: runs-on: macos-latest defaults: run: - working-directory: pgml-sdks/rust/pgml + working-directory: pgml-sdks/pgml steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -101,7 +101,7 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] defaults: run: - working-directory: pgml-sdks\rust\pgml + working-directory: pgml-sdks\pgml steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index db7b5d11f..dc5b7dada 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -2,17 +2,6 @@ # It is not intended for manual editing. 
version = 3 -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.8.3" @@ -20,6 +9,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", + "getrandom", "once_cell", "version_check", ] @@ -73,9 +63,9 @@ dependencies = [ [[package]] name = "atoi" -version = "1.0.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ "num-traits", ] @@ -88,15 +78,15 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.13.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] -name = "base64" -version = "0.21.2" +name = "base64ct" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" @@ -104,6 +94,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +dependencies = [ + "serde", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -171,6 +170,12 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "const-oid" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" + [[package]] name = "core-foundation" version = "0.9.3" @@ -275,6 +280,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "der" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -282,30 +298,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "dotenvy" version = "0.15.7" @@ -317,6 +314,9 @@ name = "either" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -333,6 +333,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.1" @@ -354,6 +360,17 @@ dependencies = [ "libc", ] +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -369,6 +386,18 @@ dependencies = [ "instant", ] +[[package]] +name = "flume" +version = "0.10.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" +dependencies = [ + "futures-core", + "futures-sink", + "pin-project", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -443,13 +472,13 @@ dependencies = [ [[package]] name = "futures-intrusive" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.11.2", + "parking_lot", ] [[package]] @@ -532,7 +561,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -551,7 +580,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" dependencies = [ - "ahash 0.8.3", + "ahash", "allocator-api2", ] @@ -612,6 +641,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" 
+dependencies = [ + "windows-sys 0.48.0", +] + [[package]] name = "http" version = "0.2.9" @@ -732,6 +770,16 @@ dependencies = [ "hashbrown 0.12.3", ] +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + [[package]] name = "indicatif" version = "0.17.6" @@ -751,6 +799,17 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +[[package]] +name = "inherent" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "instant" version = "0.1.12" @@ -806,6 +865,9 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin 0.5.2", +] [[package]] name = "libc" @@ -823,6 +885,23 @@ dependencies = [ "winapi", ] +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + +[[package]] +name = "libsqlite3-sys" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.3.8" @@ -977,6 +1056,44 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.15" @@ -984,6 +1101,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1014,7 +1132,7 @@ version = "0.10.55" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "foreign-types", "libc", @@ -1040,6 +1158,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-src" +version = "111.26.0+1.1.1u" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc62c9f12b22b8f5208c23a7200a442b2e5999f8bdf80233852122b5a4f6f37" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.90" @@ -1048,6 +1175,7 @@ checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] @@ -1058,17 +1186,6 @@ version = "0.1.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -1076,21 +1193,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -1101,7 +1204,7 @@ checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.3.5", + "redox_syscall", "smallvec", "windows-targets 0.48.0", ] @@ -1112,6 +1215,15 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.0" @@ -1120,7 +1232,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.9.0" +version = "0.9.1" dependencies = [ "anyhow", "async-trait", @@ -1146,6 +1258,26 @@ dependencies = [ "uuid", ] 
+[[package]] +name = "pin-project" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -1158,6 +1290,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.27" @@ -1195,7 +1348,7 @@ dependencies = [ "indoc", "libc", "memoffset", - "parking_lot 0.12.1", + "parking_lot", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1309,33 +1462,13 @@ dependencies = [ "getrandom", ] -[[package]] -name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags", -] - [[package]] name = "redox_syscall" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", -] - -[[package]] -name 
= "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", + "bitflags 1.3.2", ] [[package]] @@ -1361,7 +1494,7 @@ version = "0.11.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ - "base64 0.21.2", + "base64", "bytes", "encoding_rs", "futures-core", @@ -1401,12 +1534,34 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", ] +[[package]] +name = "rsa" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab43bb47d23c1a631b4b680199a45255dce26fa9ab2fa902581f624ff13e6a8" +dependencies = [ + "byteorder", + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-iter", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust_bridge" version = "0.1.0" @@ -1438,7 +1593,7 @@ version = "0.37.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", @@ -1448,14 +1603,13 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.8" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" dependencies = [ - "log", "ring", + "rustls-webpki", "sct", - "webpki", ] [[package]] @@ -1464,7 +1618,17 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.2", + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261e9e0888cba427c3316e6322805653c9425240b6fd96cee7cb671ab70ab8d0" +dependencies = [ + "ring", + "untrusted", ] [[package]] @@ -1500,10 +1664,11 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.28.5" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbab99b8cd878ab7786157b7eb8df96333a6807cc6e45e8888c85b51534b401a" +checksum = "28c05a5bf6403834be253489bbe95fa9b1e5486bc843b61f60d26b5c9c1e244b" dependencies = [ + "inherent", "sea-query-attr", "sea-query-derive", "serde_json", @@ -1523,9 +1688,9 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.3.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cea85029985b40dfbf18318d85fe985c04db7c1b4e5e8e0a0a0cdff5f1e30f9" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", @@ -1534,9 +1699,9 @@ dependencies = [ [[package]] name = "sea-query-derive" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63f62030c60f3a691f5fe251713b4e220b306e50a71e1d6f9cce1f24bb781978" +checksum = "bd78f2e0ee8e537e9195d1049b752e0433e2cac125426bccb7b5c3e508096117" dependencies = [ "heck", "proc-macro2", @@ -1551,7 +1716,7 @@ version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -1657,6 +1822,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signature" +version = "2.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e1788eed21689f9cf370582dfc467ef36ed9c707f073528ddafa8d83e3b8500" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "slab" version = "0.4.8" @@ -1688,6 +1863,25 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1e996ef02c474957d681f1b05213dfb0abab947b446a62d37770b23500184a" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlformat" version = "0.2.1" @@ -1701,99 +1895,212 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.6.3" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" +checksum = "8e58421b6bc416714d5115a2ca953718f6c621a51b68e4f4922aea5a4391a721" dependencies = [ "sqlx-core", "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", ] [[package]] name = "sqlx-core" -version = "0.6.3" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" +checksum = "dd4cef4251aabbae751a3710927945901ee1d97ee96d757f6880ebb9a79bfd53" dependencies = [ - "ahash 0.7.6", + "ahash", "atoi", - "base64 0.13.1", - "bitflags", "byteorder", "bytes", "chrono", "crc", "crossbeam-queue", - "dirs", "dotenvy", "either", "event-listener", "futures-channel", "futures-core", "futures-intrusive", + "futures-io", "futures-util", "hashlink", "hex", - "hkdf", - "hmac", - "indexmap", - "itoa", - 
"libc", + "indexmap 2.0.0", "log", - "md-5", "memchr", "once_cell", "paste", "percent-encoding", - "rand", "rustls", "rustls-pemfile", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlformat", - "sqlx-rt", - "stringprep", "thiserror", "time 0.3.22", + "tokio", "tokio-stream", + "tracing", "url", "uuid", "webpki-roots", - "whoami", ] [[package]] name = "sqlx-macros" -version = "0.6.3" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "208e3165167afd7f3881b16c1ef3f2af69fa75980897aac8874a0696516d12c2" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" +checksum = "8a4a8336d278c62231d87f24e8a7a74898156e34c1c18942857be2acb29c7dfc" dependencies = [ "dotenvy", "either", "heck", + "hex", "once_cell", "proc-macro2", "quote", + "serde", "serde_json", "sha2", "sqlx-core", - "sqlx-rt", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", "syn 1.0.109", + "tempfile", + "tokio", "url", ] [[package]] -name = "sqlx-rt" -version = "0.6.3" +name = "sqlx-mysql" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" +checksum = "8ca69bf415b93b60b80dc8fda3cb4ef52b2336614d8da2de5456cc942a110482" dependencies = [ + "atoi", + "base64", + "bitflags 2.4.0", + "byteorder", + "bytes", + "chrono", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", "once_cell", - "tokio", - "tokio-rustls", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time 
0.3.22", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0db2df1b8731c3651e204629dd55e52adbae0462fa1bdcbed56a2302c18181e" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time 0.3.22", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4c21bf34c7cae5b283efb3ac1bcc7670df7561124dc2f8bdc0b59be40f79a2" +dependencies = [ + "atoi", + "chrono", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "time 0.3.22", + "tracing", + "url", + "uuid", ] [[package]] @@ -1866,7 +2173,7 @@ dependencies = [ "autocfg", "cfg-if", "fastrand", - "redox_syscall 0.3.5", + "redox_syscall", "rustix", "windows-sys 0.48.0", ] @@ -1992,17 +2299,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -2041,6 +2337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2304,23 +2601,13 
@@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "webpki-roots" -version = "0.22.6" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" +checksum = "b291546d5d9d1eab74f069c77749f2cb8504a12caa20f0f2de93ddbf6f411888" dependencies = [ - "webpki", + "rustls-webpki", ] [[package]] @@ -2328,10 +2615,6 @@ name = "whoami" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c70234412ca409cc04e864e89523cb0fc37f5e1344ebed5a3ebf4192b6b9f68" -dependencies = [ - "wasm-bindgen", - "web-sys", -] [[package]] name = "winapi" @@ -2504,3 +2787,9 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] + +[[package]] +name = "zeroize" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index c854cb61b..b3d15786a 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.9.0" +version = "0.9.1" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" @@ -15,7 +15,7 @@ crate-type = ["lib", "cdylib"] [dependencies] rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} -sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid", "chrono"] } +sqlx = { version = "0.7", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid", "chrono"] } serde_json = "1.0.9" anyhow = "1.0.9" tokio = { version = "1.28.2", 
features = [ "macros" ] } @@ -26,10 +26,10 @@ neon = { version = "0.10", optional = true, default-features = false, features = itertools = "0.10.5" uuid = {version = "1.3.3", features = ["v4", "serde"] } md5 = "0.7.0" -sea-query = { version = "0.28.5", features = ["attr", "thread-safe", "with-json", "postgres-array"] } -sea-query-binder = { version = "0.3.1", features = ["sqlx-postgres", "with-json", "postgres-array"] } +sea-query = { version = "0.30.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } +sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } regex = "1.8.4" -reqwest = { version = "0.11", features = ["json"] } +reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } async-trait = "0.1.71" tracing = { version = "0.1.37" } tracing-subscriber = { version = "0.3.17", features = ["json"] } diff --git a/pgml-sdks/pgml/javascript/README.md b/pgml-sdks/pgml/javascript/README.md index 77a687833..de4acede9 100644 --- a/pgml-sdks/pgml/javascript/README.md +++ b/pgml-sdks/pgml/javascript/README.md @@ -208,9 +208,11 @@ const collection = pgml.newCollection("test_collection", CUSTOM_DATABASE_URL) ### Upserting Documents -Documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. +The `upsert_documents` method can be used to insert new documents and update existing documents. -**Upsert documents with metadata** +New documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. + +**Upsert new documents with metadata** ```javascript const documents = [ { @@ -228,6 +230,98 @@ const collection = pgml.newCollection("test_collection") await collection.upsert_documents(documents) ``` +Document metadata can be updated by upserting the document without the `text` key. 
+ +**Update document metadata** +```javascript +documents = [ + { + id: "Document 1", + random_key: "this will be NEW metadata for the document" + }, + { + id: "Document 2", + random_key: "this will be NEW metadata for the document" + } +] +collection = pgml.newCollection("test_collection") +await collection.upsert_documents(documents) +``` + +### Getting Documents + +Documents can be retrieved using the `get_documents` method on the collection object + +**Get the first 100 documents** +```javascript +collection = pgml.newCollection("test_collection") +documents = await collection.get_documents({ limit: 100 }) +``` + +#### Pagination + +The JavaScript SDK supports limit-offset pagination and keyset pagination + +**Limit-Offset pagination** +```javascript +collection = pgml.newCollection("test_collection") +documents = await collection.get_documents({ limit: 100, offset: 10 }) +``` + +**Keyset pagination** +```javascript +collection = pgml.newCollection("test_collection") +documents = await collection.get_documents({ limit: 100, last_row_id: 10 }) +``` + +The `last_row_id` can be taken from the `row_id` field in the returned document's dictionary. + +#### Filtering + +Metadata and full text filtering are supported just like they are in vector recall. + +**Metadata and full text filtering** +```javascript +collection = pgml.newCollection("test_collection") +documents = await collection.get_documents({ + limit: 100, + offset: 10, + filter: { + metadata: { + id: { + $eq: 1 + } + }, + full_text_search: { + configuration: "english", + text: "Some full text query" + } + } +}) + +``` + +### Deleting Documents + +Documents can be deleted with the `delete_documents` method on the collection object. + +Metadata and full text filtering are supported just like they are in vector recall. 
+ +```javascript +collection = pgml.newCollection("test_collection") +documents = await collection.delete_documents({ + metadata: { + id: { + $eq: 1 + } + }, + full_text_search: { + configuration: "english", + text: "Some full text query" + } +}) +``` + ### Searching Collections The JavaScript SDK is specifically designed to provide powerful, flexible vector search. @@ -326,7 +420,7 @@ const results = await collection.query() .fetch_all() ``` -The above query would filter out all documents that do not have a key `special` with a value `True` or (have a key `uuid` equal to 1 and a key `index` less than 100). +The above query would filter out all documents that do not have a key `special` with a value `true` or (have a key `uuid` equal to 1 and a key `index` less than 100). #### Full Text Filtering @@ -418,7 +512,7 @@ const model = pgml.newModel() const splitter = pgml.newSplitter() const pipeline = pgml.newPipeline("test_pipeline", model, splitter, { "full_text_search": { - active: True, + active: true, configuration: "english" } }) diff --git a/pgml-sdks/pgml/javascript/build.js b/pgml-sdks/pgml/javascript/build.js new file mode 100644 index 000000000..30fe55bfa --- /dev/null +++ b/pgml-sdks/pgml/javascript/build.js @@ -0,0 +1,48 @@ +const os = require("os"); +const { exec } = require("node:child_process"); + +const type = os.type(); +const arch = os.arch(); + +const set_name = (type, arch) => { + if (type == "Darwin" && arch == "x64") { + return "x86_64-apple-darwin-index.node"; + } else if (type == "Darwin" && arch == "arm64") { + return "aarch64-apple-darwin-index.node"; + } else if ((type == "Windows" || type == "Windows_NT") && arch == "x64") { + return "x86_64-pc-windows-gnu-index.node"; + } else if (type == "Linux" && arch == "x64") { + return "x86_64-unknown-linux-gnu-index.node"; + } else if (type == "Linux" && arch == "arm64") { + return "aarch64-unknown-linux-gnu-index.node"; + } else { + console.log("UNSUPPORTED TYPE OR ARCH:", type, arch); + 
process.exit(1); + } +}; + +let name = set_name(type, arch); + +let args = process.argv.slice(2); +let release = args.includes("--release"); + +let shell_args = + type == "Windows" || type == "Windows_NT" ? { shell: "powershell.exe" } : {}; + +exec( + ` + rm -r dist; + mkdir dist; + npx cargo-cp-artifact -nc "${name}" -- cargo build --message-format=json-render-diagnostics -F javascript ${release ? "--release" : ""}; + mv ${name} dist; + `, + shell_args, + (err, stdout, stderr) => { + if (err) { + console.log("ERR:", err); + } else { + console.log("STDOUT:", stdout); + console.log("STDERR:", stderr); + } + }, +); diff --git a/pgml-sdks/pgml/javascript/index.js b/pgml-sdks/pgml/javascript/index.js index 5ebc5b4d3..47ab75a8e 100644 --- a/pgml-sdks/pgml/javascript/index.js +++ b/pgml-sdks/pgml/javascript/index.js @@ -3,26 +3,22 @@ const os = require("os") const type = os.type() const arch = os.arch() -try { - const pgml = require("./index.node") +if (type == "Darwin" && arch == "x64") { + const pgml = require("./dist/x86_64-apple-darwin-index.node") module.exports = pgml -} catch (e) { - if (type == "Darwin" && arch == "x64") { - const pgml = require("./dist/x86_64-apple-darwin-index.node") - module.exports = pgml - } else if (type == "Darwin" && arch == "arm64") { - const pgml = require("./dist/aarch64-apple-darwin-index.node") - module.exports = pgml - } else if (type == "Windows" && arch == "x64") { - const pgml = require("./dist/x86_64-pc-windows-gnu-index.node") - module.exports = pgml - } else if (type == "Linux" && arch == "x64") { - const pgml = require("./dist/x86_64-unknown-linux-gnu-index.node") - module.exports = pgml - } else if (type == "Linux" && arch == "arm64") { - const pgml = require("./dist/aarch64-unknown-linux-gnu-index.node") - module.exports = pgml - } else { - console.log("UNSUPPORTED TYPE OR ARCH:", type, arch) - } +} else if (type == "Darwin" && arch == "arm64") { + const pgml = require("./dist/aarch64-apple-darwin-index.node") + 
module.exports = pgml +} else if ((type == "Windows" || type == "Windows_NT") && arch == "x64") { + const pgml = require("./dist/x86_64-pc-windows-gnu-index.node") + module.exports = pgml +} else if (type == "Linux" && arch == "x64") { + const pgml = require("./dist/x86_64-unknown-linux-gnu-index.node") + module.exports = pgml +} else if (type == "Linux" && arch == "arm64") { + const pgml = require("./dist/aarch64-unknown-linux-gnu-index.node") + module.exports = pgml +} else { + console.log("UNSUPPORTED TYPE OR ARCH:", type, arch) + process.exit(1); } diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 771b8c24e..551b7156d 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,13 +1,17 @@ { "name": "pgml", - "version": "0.9.0", + "version": "0.9.1", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", - "keywords": ["postgres", "machine learning", "vector databases", "embeddings"], + "keywords": [ + "postgres", + "machine learning", + "vector databases", + "embeddings" + ], "main": "index.js", "scripts": { - "build": "cargo-cp-artifact -nc index.node -- cargo build --message-format=json-render-diagnostics -F javascript", - "build-debug": "npm run build --", - "build-release": "npm run build -- --release" + "build": "node build.js", + "build-release": "node build.js --release" }, "author": { "name": "PostgresML", diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index 5e5b76061..f4895edf4 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -21,6 +21,7 @@ const generate_dummy_documents = (count: number) => { project: "a10", uuid: i * 10, floating_uuid: i * 1.1, + test: null, name: `Test Document ${i}`, }); } @@ -156,3 +157,66 @@ it("pipeline to 
dict", async () => { expect(pipeline_dict["name"]).toBe("test_j_p_ptd_0"); await collection.archive(); }); + +/////////////////////////////////////////////////// +// Test document related functions //////////////// +/////////////////////////////////////////////////// + +it("can upsert and get documents", async () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline("test_p_p_cuagd_0", model, splitter, { + full_text_search: { active: true, configuration: "english" }, + }); + let collection = pgml.newCollection("test_p_c_cuagd_1"); + await collection.add_pipeline(pipeline); + await collection.upsert_documents(generate_dummy_documents(10)); + + let documents = await collection.get_documents(); + expect(documents).toHaveLength(10); + + documents = await collection.get_documents({ + offset: 1, + limit: 2, + filter: { metadata: { id: { $gt: 0 } } }, + }); + expect(documents).toHaveLength(2); + expect(documents[0]["document"]["id"]).toBe(2); + let last_row_id = documents[1]["row_id"]; + + documents = await collection.get_documents({ + filter: { + metadata: { id: { $gt: 3 } }, + full_text_search: { configuration: "english", text: "4" }, + }, + last_row_id: last_row_id, + }); + expect(documents).toHaveLength(1); + expect(documents[0]["document"]["id"]).toBe(4); + + await collection.archive(); +}); + +it("can delete documents", async () => { + let model = pgml.newModel(); + let splitter = pgml.newSplitter(); + let pipeline = pgml.newPipeline( + "test_p_p_cdd_0", + model, + splitter, + + { full_text_search: { active: true, configuration: "english" } }, + ); + let collection = pgml.newCollection("test_p_c_cdd_2"); + await collection.add_pipeline(pipeline); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.delete_documents({ + metadata: { id: { $gte: 0 } }, + full_text_search: { configuration: "english", text: "0" }, + }); + let documents = await collection.get_documents(); + 
expect(documents).toHaveLength(2); + expect(documents[0]["document"]["id"]).toBe(1); + + await collection.archive(); +}); diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index bd5d98d7f..df7bfa417 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.9.0" +version = "0.9.1" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/python/README.md b/pgml-sdks/pgml/python/README.md index b6b70aaea..a05c184ce 100644 --- a/pgml-sdks/pgml/python/README.md +++ b/pgml-sdks/pgml/python/README.md @@ -213,9 +213,11 @@ collection = Collection("test_collection", CUSTOM_DATABASE_URL) ### Upserting Documents -Documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. +The `upsert_documents` method can be used to insert new documents and update existing documents. -**Upsert documents with metadata** +New documents are dictionaries with two required keys: `id` and `text`. All other keys/value pairs are stored as metadata for the document. + +**Upsert new documents with metadata** ```python documents = [ { @@ -233,6 +235,97 @@ collection = Collection("test_collection") await collection.upsert_documents(documents) ``` +Document metadata can be updated by upserting the document without the `text` key. 
+ +**Update document metadata** +```python +documents = [ + { + "id": "Document 1", + "random_key": "this will be NEW metadata for the document" + }, + { + "id": "Document 2", + "random_key": "this will be NEW metadata for the document" + } +] +collection = Collection("test_collection") +await collection.upsert_documents(documents) +``` + +### Getting Documents + +Documents can be retrieved using the `get_documents` method on the collection object + +**Get the first 100 documents** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ "limit": 100 }) +``` + +#### Pagination + +The Python SDK supports limit-offset pagination and keyset pagination + +**Limit-Offset pagination** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ "limit": 100, "offset": 10 }) +``` + +**Keyset pagination** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ "limit": 100, "last_row_id": 10 }) +``` + +The `last_row_id` can be taken from the `row_id` field in the returned document's dictionary. + +#### Filtering + +Metadata and full text filtering are supported just like they are in vector recall. + +**Metadata and full text filtering** +```python +collection = Collection("test_collection") +documents = await collection.get_documents({ + "limit": 100, + "offset": 10, + "filter": { + "metadata": { + "id": { + "$eq": 1 + } + }, + "full_text_search": { + "configuration": "english", + "text": "Some full text query" + } + } +}) + +``` + +### Deleting Documents + +Documents can be deleted with the `delete_documents` method on the collection object. + +Metadata and full text filtering are supported just like they are in vector recall. 
+ +```python +documents = await collection.delete_documents({ + "metadata": { + "id": { + "$eq": 1 + } + }, + "full_text_search": { + "configuration": "english", + "text": "Some full text query" + } +}) +``` + ### Searching Collections The Python SDK is specifically designed to provide powerful, flexible vector search. @@ -350,7 +443,7 @@ results = ( .vector_recall("Here is some query", pipeline) .limit(10) .filter({ - "full_text": { + "full_text_search": { "configuration": "english", "text": "Match Me" } diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 02895348d..9ef3103be 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -3,3 +3,89 @@ def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> Non Json = Any DateTime = int + +# Top of file key: A12BECOD! +from typing import List, Dict, Optional, Self, Any + + +class Builtins: + def __init__(self, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self + ... + def query(self, query: str) -> QueryRunner + ... + async def transform(self, task: Json, inputs: List[str], args: Optional[Json] = Any) -> Json + ... + +class Collection: + def __init__(self, name: str, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self + ... + async def add_pipeline(self, pipeline: Pipeline) -> None + ... + async def remove_pipeline(self, pipeline: Pipeline) -> None + ... + async def enable_pipeline(self, pipeline: Pipeline) -> None + ... + async def disable_pipeline(self, pipeline: Pipeline) -> None + ... + async def upsert_documents(self, documents: List[Json]) -> None + ... + async def get_documents(self, args: Optional[Json] = Any) -> List[Json] + ... + async def delete_documents(self, filter: Json) -> None + ... 
+ async def vector_search(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any, top_k: Optional[int] = 1) -> List[tuple[float, str, Json]] + ... + async def archive(self) -> None + ... + def query(self) -> QueryBuilder + ... + async def get_pipelines(self) -> List[Pipeline] + ... + async def get_pipeline(self, name: str) -> Pipeline + ... + async def exists(self) -> bool + ... + +class Model: + def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", source: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self + ... + +class Pipeline: + def __init__(self, name: str, model: Optional[Model] = Any, splitter: Optional[Splitter] = Any, parameters: Optional[Json] = Any) -> Self + ... + async def get_status(self) -> PipelineSyncData + ... + async def to_dict(self) -> Json + ... + +class QueryBuilder: + def limit(self, limit: int) -> Self + ... + def filter(self, filter: Json) -> Self + ... + def vector_recall(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any) -> Self + ... + async def fetch_all(self) -> List[tuple[float, str, Json]] + ... + def to_full_string(self) -> str + ... + +class QueryRunner: + async def fetch_all(self) -> Json + ... + async def execute(self) -> None + ... + def bind_string(self, bind_value: str) -> Self + ... + def bind_int(self, bind_value: int) -> Self + ... + def bind_float(self, bind_value: float) -> Self + ... + def bind_bool(self, bind_value: bool) -> Self + ... + def bind_json(self, bind_value: Json) -> Self + ... + +class Splitter: + def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self + ... 
diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 88c19685d..a355b27a8 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -32,6 +32,7 @@ def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: "project": "a10", "floating_uuid": i * 1.01, "uuid": i * 10, + "test": None, "name": "Test Document {}".format(i), } ) @@ -181,6 +182,74 @@ async def test_pipeline_to_dict(): await collection.archive() +################################################### +## Test document related functions ################ +################################################### + + +@pytest.mark.asyncio +async def test_upsert_and_get_documents(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.Pipeline( + "test_p_p_tuagd_0", + model, + splitter, + {"full_text_search": {"active": True, "configuration": "english"}}, + ) + collection = pgml.Collection(name="test_p_c_tuagd_2") + await collection.add_pipeline( + pipeline, + ) + await collection.upsert_documents(generate_dummy_documents(10)) + + documents = await collection.get_documents() + assert len(documents) == 10 + + documents = await collection.get_documents( + {"offset": 1, "limit": 2, "filter": {"metadata": {"id": {"$gt": 0}}}} + ) + assert len(documents) == 2 and documents[0]["document"]["id"] == 2 + last_row_id = documents[-1]["row_id"] + + documents = await collection.get_documents( + { + "filter": { + "metadata": {"id": {"$gt": 3}}, + "full_text_search": {"configuration": "english", "text": "4"}, + }, + "last_row_id": last_row_id, + } + ) + assert len(documents) == 1 and documents[0]["document"]["id"] == 4 + + await collection.archive() + + +@pytest.mark.asyncio +async def test_delete_documents(): + model = pgml.Model() + splitter = pgml.Splitter() + pipeline = pgml.Pipeline( + "test_p_p_tdd_0", + model, + splitter, + {"full_text_search": {"active": True, "configuration": "english"}}, + ) + collection = 
pgml.Collection("test_p_c_tdd_1") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(3)) + await collection.delete_documents( + { + "metadata": {"id": {"$gte": 0}}, + "full_text_search": {"configuration": "english", "text": "0"}, + } + ) + documents = await collection.get_documents() + assert len(documents) == 2 and documents[0]["document"]["id"] == 1 + await collection.archive() + + ################################################### ## Test with multiprocessing ###################### ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 3686c1c1b..60465c130 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -10,7 +10,6 @@ pub struct Builtins { use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; - #[cfg(feature = "python")] use crate::{query_runner::QueryRunnerPython, types::JsonPython}; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index ba0a1af3e..23fe6df42 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -2,6 +2,8 @@ use anyhow::Context; use indicatif::MultiProgress; use itertools::Itertools; use rust_bridge::{alias, alias_methods}; +use sea_query::{Alias, Expr, JoinType, Order, PostgresQueryBuilder, Query}; +use sea_query_binder::SqlxBinder; use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; @@ -9,10 +11,18 @@ use std::borrow::Cow; use std::time::SystemTime; use tracing::{instrument, warn}; +use crate::filter_builder; use crate::{ - get_or_initialize_pool, model::ModelRuntime, models, pipeline::Pipeline, queries, - query_builder, query_builder::QueryBuilder, remote_embeddings::build_remote_embeddings, - splitter::Splitter, types::DateTime, types::Json, utils, + get_or_initialize_pool, + model::ModelRuntime, + models, + pipeline::Pipeline, + queries, query_builder, + 
query_builder::QueryBuilder, + remote_embeddings::build_remote_embeddings, + splitter::Splitter, + types::{DateTime, IntoTableNameAndSchema, Json, SIden, TryToNumeric}, + utils, }; #[cfg(feature = "python")] @@ -101,6 +111,7 @@ pub struct Collection { new, upsert_documents, get_documents, + delete_documents, get_pipelines, get_pipeline, add_pipeline, @@ -192,7 +203,7 @@ impl Collection { let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT") .bind(&self.name) - .fetch_one(&mut transaction) + .fetch_one(&mut *transaction) .await?; transaction @@ -202,7 +213,7 @@ impl Collection { let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id) VALUES ($1, $2) ON CONFLICT (name) DO NOTHING RETURNING *") .bind(&self.name) .bind(project_id) - .fetch_one(&mut transaction) + .fetch_one(&mut *transaction) .await?; let collection_database_data = CollectionDatabaseData { @@ -320,7 +331,7 @@ impl Collection { "DROP TABLE IF EXISTS %s", embeddings_table_name )) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; // Need to delete from the tsvectors table only if no other pipelines use the @@ -331,7 +342,7 @@ impl Collection { self.pipelines_table_name)) .bind(parameters["full_text_search"]["configuration"].as_str()) .bind(database_data.id) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; sqlx::query(&query_builder!( @@ -339,7 +350,7 @@ impl Collection { self.pipelines_table_name )) .bind(database_data.id) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; transaction.commit().await?; @@ -541,60 +552,42 @@ impl Collection { /// serde_json::json!({"id": 1, "text": "hello world"}).into(), /// serde_json::json!({"id": 2, "text": "hello world"}).into(), /// ]; - /// collection.upsert_documents(documents, Some(true)).await?; + /// 
collection.upsert_documents(documents).await?; /// Ok(()) /// } /// ``` #[instrument(skip(self, documents))] - pub async fn upsert_documents( - &mut self, - documents: Vec, - strict: Option, - ) -> anyhow::Result<()> { + pub async fn upsert_documents(&mut self, documents: Vec) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; self.verify_in_database(false).await?; - let strict = strict.unwrap_or(true); - let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - let documents: anyhow::Result> = documents.into_iter().map(|mut document| { - let document = document - .as_object_mut() - .expect("Documents must be a vector of objects"); - let text = match document.remove("text") { - Some(t) => t, - None => { - if strict { - anyhow::bail!("`text` is not a key in document, throwing error. To supress this error, pass strict: false"); - } else { - warn!("`text` is not a key in document, skipping document. To throw an error instead, pass strict: true"); - } - return Ok(None) - } - }; - let text = text.as_str().context("`text` must be a string")?.to_string(); - - // We don't want the text included in the document metadata, but everything else - // should be in there - let metadata = serde_json::to_value(&document)?.into(); - - let md5_digest = match document.get("id") { - Some(k) => md5::compute(k.to_string().as_bytes()), - None => { - if strict { - anyhow::bail!("`id` is not a key in document, throwing error. To supress this error, pass strict: false"); - } else { - warn!("`id` is not a key in document, skipping document. 
To throw an error instead, pass strict: true"); - } - return Ok(None) - } - }; - let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; - - Ok(Some((source_uuid, text, metadata))) - }).collect(); + let documents: anyhow::Result> = documents + .into_iter() + .map(|mut document| { + let document = document + .as_object_mut() + .expect("Documents must be a vector of objects"); + let text = document + .remove("text") + .map(|t| t.as_str().expect("`text` must be a string").to_string()); + + // We don't want the text included in the document metadata, but everything else + // should be in there + let metadata = serde_json::to_value(&document)?.into(); + + let id = document + .get("id") + .context("`id` must be a key in document")? + .to_string(); + let md5_digest = md5::compute(id.as_bytes()); + let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + + Ok(Some((source_uuid, text, metadata))) + }) + .collect(); // We could continue chaining the above iterators but types become super annoying to // deal with, especially because we are dealing with async functions. This is much easier to read @@ -606,26 +599,41 @@ // We want the length before we filter out any None values let chunk_len = chunk.len(); // Filter out the None values - let chunk: Vec<&(uuid::Uuid, String, Json)> = + let mut chunk: Vec<&(uuid::Uuid, Option, Json)> = chunk.iter().filter_map(|x| x.as_ref()).collect(); - // Make sure we didn't filter everything out + // If the chunk is empty, we can skip the rest of the loop if chunk.is_empty() { progress_bar.inc(chunk_len as u64); continue; } + // Split the chunk into two groups, one with text, and one with just metadata + let split_index = itertools::partition(&mut chunk, |(_, text, _)| text.is_some()); + let (text_chunk, metadata_chunk) = chunk.split_at(split_index); + + // Start the transaction let mut transaction = pool.begin().await?; - // First delete any documents that already have the same UUID then insert the new ones.
+ + // Update the metadata + sqlx::query(query_builder!( + "UPDATE %s d SET metadata = v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", + self.documents_table_name + ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) + .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) + .execute(&mut *transaction).await?; + + // First delete any documents that already have the same UUID as documents in + // text_chunk, then insert the new ones. // We are essentially upserting in two steps sqlx::query(&query_builder!( "DELETE FROM %s WHERE source_uuid IN (SELECT source_uuid FROM %s WHERE source_uuid = ANY($1::uuid[]))", self.documents_table_name, self.documents_table_name )). - bind(&chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()). - execute(&mut transaction).await?; - let query_string_values = (0..chunk.len()) + bind(&text_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()). + execute(&mut *transaction).await?; + let query_string_values = (0..text_chunk.len()) .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) .collect::>() .join(","); @@ -635,11 +643,10 @@ impl Collection { ); let query = query_builder!(query_string, self.documents_table_name); let mut query = sqlx::query_scalar(&query); - for (source_uuid, text, metadata) in chunk.into_iter() { + for (source_uuid, text, metadata) in text_chunk.iter() { query = query.bind(source_uuid).bind(text).bind(metadata); } - - let ids: Vec = query.fetch_all(&mut transaction).await?; + let ids: Vec = query.fetch_all(&mut *transaction).await?; document_ids.extend(ids); progress_bar.inc(chunk_len as u64); transaction.commit().await?; @@ -655,8 +662,7 @@ impl Collection { /// /// # Arguments /// - /// * `last_id` - The last id of the document to get. If none, starts at 0 - /// * `limit` - The number of documents to get. 
If none, gets 100 + /// * `args` - The filters and options to apply to the query /// /// # Example /// @@ -665,36 +671,190 @@ impl Collection { /// /// async fn example() -> anyhow::Result<()> { /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.get_documents(None, None).await?; + /// let documents = collection.get_documents(None).await?; /// Ok(()) /// } #[instrument(skip(self))] - pub async fn get_documents( - &mut self, - last_id: Option, - limit: Option, - ) -> anyhow::Result> { + pub async fn get_documents(&mut self, args: Option) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let documents: Vec = sqlx::query_as(&query_builder!( - "SELECT * FROM %s WHERE id > $1 ORDER BY id ASC LIMIT $2", - self.documents_table_name - )) - .bind(last_id.unwrap_or(0)) - .bind(limit.unwrap_or(100)) - .fetch_all(&pool) - .await?; - documents + + let mut args = args.unwrap_or_default().0; + let args = args.as_object_mut().context("args must be an object")?; + + // Get limit or set it to 1000 + let limit = args + .remove("limit") + .map(|l| l.try_to_u64()) + .unwrap_or(Ok(1000))?; + + let mut query = Query::select(); + query + .from_as( + self.documents_table_name.to_table_tuple(), + SIden::Str("documents"), + ) + .expr(Expr::cust("*")) // Adds the * in SELECT * FROM + .order_by((SIden::Str("documents"), SIden::Str("id")), Order::Asc) + .limit(limit); + + if let Some(last_row_id) = args.remove("last_row_id") { + let last_row_id = last_row_id + .try_to_u64() + .context("last_row_id must be an integer")?; + query.and_where(Expr::col((SIden::Str("documents"), SIden::Str("id"))).gt(last_row_id)); + } + + if let Some(offset) = args.remove("offset") { + let offset = offset.try_to_u64().context("offset must be an integer")?; + query.offset(offset); + } + + if let Some(mut filter) = args.remove("filter") { + let filter = filter + .as_object_mut() + .context("filter must be a Json object")?; + + if let 
Some(f) = filter.remove("metadata") { + query.cond_where( + filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), + ); + } + if let Some(f) = filter.remove("full_text_search") { + let f = f + .as_object() + .context("Full text filter must be a Json object")?; + let configuration = f + .get("configuration") + .context("In full_text_search `configuration` is required")? + .as_str() + .context("In full_text_search `configuration` must be a string")?; + let filter_text = f + .get("text") + .context("In full_text_search `text` is required")? + .as_str() + .context("In full_text_search `text` must be a string")?; + query + .join_as( + JoinType::InnerJoin, + self.documents_tsvectors_table_name.to_table_tuple(), + Alias::new("documents_tsvectors"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), + ) + .and_where( + Expr::col(( + SIden::Str("documents_tsvectors"), + SIden::Str("configuration"), + )) + .eq(configuration), + ) + .and_where(Expr::cust_with_values( + format!( + "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", + configuration + ), + [filter_text], + )); + } + } + + let (sql, values) = query.build_sqlx(PostgresQueryBuilder); + let documents: Vec = + sqlx::query_as_with(&sql, values).fetch_all(&pool).await?; + Ok(documents .into_iter() - .map(|d| { - serde_json::to_value(d) - .map(|t| t.into()) - .map_err(|e| anyhow::anyhow!(e)) - }) - .collect() + .map(|d| d.into_user_friendly_json()) + .collect()) + } + /// Deletes documents in a [Collection] + /// + /// # Arguments + /// + /// * `filter` - The filters to apply + /// + /// # Example + /// + /// ``` + /// use pgml::Collection; + /// + /// async fn example() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None); + /// let documents = collection.delete_documents(serde_json::json!({ + /// "metadata": { + /// "id": { + /// "eq": 1 + /// } + /// } + /// }).into()).await?; + 
/// Ok(()) + /// } + #[instrument(skip(self))] + pub async fn delete_documents(&mut self, mut filter: Json) -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&self.database_url).await?; + + let mut query = Query::delete(); + query.from_table(self.documents_table_name.to_table_tuple()); + + let filter = filter + .as_object_mut() + .context("filter must be a Json object")?; + + if let Some(f) = filter.remove("metadata") { + query + .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build()); + } + + if let Some(mut f) = filter.remove("full_text_search") { + let f = f + .as_object_mut() + .context("Full text filter must be a Json object")?; + let configuration = f + .get("configuration") + .context("In full_text_search `configuration` is required")? + .as_str() + .context("In full_text_search `configuration` must be a string")?; + let filter_text = f + .get("text") + .context("In full_text_search `text` is required")? + .as_str() + .context("In full_text_search `text` must be a string")?; + let mut inner_select_query = Query::select(); + inner_select_query + .from_as( + self.documents_tsvectors_table_name.to_table_tuple(), + SIden::Str("documents_tsvectors"), + ) + .column(SIden::Str("document_id")) + .and_where(Expr::cust_with_values( + format!( + "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", + configuration + ), + [filter_text], + )) + .and_where( + Expr::col(( + SIden::Str("documents_tsvectors"), + SIden::Str("configuration"), + )) + .eq(configuration), + ); + query.and_where( + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .in_subquery(inner_select_query), + ); + } + + let (sql, values) = query.build_sqlx(PostgresQueryBuilder); + sqlx::query_with(&sql, values).fetch_all(&pool).await?; + Ok(()) } #[instrument(skip(self))] - pub async fn sync_pipelines(&mut self, document_ids: Option>) -> anyhow::Result<()> { + pub(crate) async fn sync_pipelines( + &mut self, + document_ids: Option>, + ) -> anyhow::Result<()> { 
self.verify_in_database(false).await?; let pipelines = self.get_pipelines().await?; if !pipelines.is_empty() { @@ -711,10 +871,6 @@ impl Collection { .expect("Failed to execute pipeline"); }) .await; - // pipelines.into_iter().for_each - // for mut pipeline in pipelines { - // pipeline.execute(&document_ids, mp.clone()).await?; - // } eprintln!("Done Syncing Pipelines\n"); } Ok(()) @@ -878,14 +1034,14 @@ impl Collection { sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2") .bind(&archive_table_name) .bind(&self.name) - .execute(&mut transaciton) + .execute(&mut *transaciton) .await?; sqlx::query(&query_builder!( "ALTER SCHEMA %s RENAME TO %s", &self.name, archive_table_name )) - .execute(&mut transaciton) + .execute(&mut *transaciton) .await?; transaciton.commit().await?; Ok(()) @@ -1062,45 +1218,3 @@ impl Collection { .unwrap() } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::init_logger; - - #[sqlx::test] - async fn can_upsert_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cud_2", None); - - // Test basic upsert - let documents = vec![ - serde_json::json!({"id": 1, "text": "hello world"}).into(), - serde_json::json!({"text": "hello world"}).into(), - ]; - collection - .upsert_documents(documents.clone(), Some(false)) - .await?; - let document = &collection.get_documents(None, Some(1)).await?[0]; - assert_eq!(document["text"], "hello world"); - - // Test strictness - assert!(collection - .upsert_documents(documents, Some(true)) - .await - .is_err()); - - // Test upsert - let documents = vec![ - serde_json::json!({"id": 1, "text": "hello world 2"}).into(), - serde_json::json!({"text": "hello world"}).into(), - ]; - collection - .upsert_documents(documents.clone(), Some(false)) - .await?; - let document = &collection.get_documents(None, Some(1)).await?[0]; - assert_eq!(document["text"], "hello world 2"); - collection.archive().await?; - Ok(()) - } -} 
diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs index 20e6c1acc..cf32ffa4b 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -116,9 +116,7 @@ fn get_value_type(value: &serde_json::Value) -> String { get_value_type(value) } else if value.is_string() { "text".to_string() - } else if value.is_i64() { - "float8".to_string() - } else if value.is_f64() { + } else if value.is_i64() || value.is_f64() { "float8".to_string() } else if value.is_boolean() { "bool".to_string() @@ -278,29 +276,35 @@ mod tests { } #[test] - fn eq_ne_comparison_operators() { - let basic_comparison_operators = vec!["", "NOT "]; - let basic_comparison_operators_names = vec!["$eq", "$ne"]; - for (operator, name) in basic_comparison_operators - .into_iter() - .zip(basic_comparison_operators_names.into_iter()) - { - let sql = construct_filter_builder_with_json(json!({ - "id": {name: 1}, - "id2": {"id3": {name: "test"}}, - "id4": {"id5": {"id6": {name: true}}}, - "id7": {"id8": {"id9": {"id10": {name: [1, 2, 3]}}}} - })) - .build() - .to_valid_sql_query(); - assert_eq!( - sql, - format!( - r##"SELECT "id" FROM "test_table" WHERE {}"test_table"."metadata" @> E'{{\"id\":1}}' AND {}"test_table"."metadata" @> E'{{\"id2\":{{\"id3\":\"test\"}}}}' AND {}"test_table"."metadata" @> E'{{\"id4\":{{\"id5\":{{\"id6\":true}}}}}}' AND {}"test_table"."metadata" @> E'{{\"id7\":{{\"id8\":{{\"id9\":{{\"id10\":[1,2,3]}}}}}}}}'"##, - operator, operator, operator, operator - ) - ); - } + fn eq_operator() { + let sql = construct_filter_builder_with_json(json!({ + "id": {"$eq": 1}, + "id2": {"id3": {"$eq": "test"}}, + "id4": {"id5": {"id6": {"$eq": true}}}, + "id7": {"id8": {"id9": {"id10": {"$eq": [1, 2, 3]}}}} + })) + .build() + .to_valid_sql_query(); + assert_eq!( + sql, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}' AND 
("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"## + ); + } + + #[test] + fn ne_operator() { + let sql = construct_filter_builder_with_json(json!({ + "id": {"$ne": 1}, + "id2": {"id3": {"$ne": "test"}}, + "id4": {"id5": {"id6": {"$ne": true}}}, + "id7": {"id8": {"id9": {"id10": {"$ne": [1, 2, 3]}}}} + })) + .build() + .to_valid_sql_query(); + assert_eq!( + sql, + r##"SELECT "id" FROM "test_table" WHERE (NOT ("test_table"."metadata")) @> E'{\"id\":1}' AND (NOT ("test_table"."metadata")) @> E'{\"id2\":{\"id3\":\"test\"}}' AND (NOT ("test_table"."metadata")) @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND (NOT ("test_table"."metadata")) @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"## + ); } #[test] @@ -320,7 +324,7 @@ mod tests { assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} 1 AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} 1"##, + r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata"#>>'{{id}}')::float8) {} 1 AND (("test_table"."metadata"#>>'{{id2,id3}}')::float8) {} 1"##, operator, operator ) ); @@ -344,7 +348,7 @@ mod tests { assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} (1) AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} (1)"##, + r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata"#>>'{{id}}')::float8) {} (1) AND (("test_table"."metadata"#>>'{{id2,id3}}')::float8) {} (1)"##, operator, operator ) ); @@ -363,7 +367,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"## ); } @@ -379,7 +383,7 
@@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"## ); } @@ -395,7 +399,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"## + r##"SELECT "id" FROM "test_table" WHERE NOT (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}')"## ); } @@ -415,7 +419,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') AND "test_table"."metadata" @> E'{\"id4\":1}'"## + r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') AND ("test_table"."metadata") @> E'{\"id4\":1}'"## ); let sql = construct_filter_builder_with_json(json!({ "$or": [ @@ -431,7 +435,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') OR "test_table"."metadata" @> E'{\"id4\":1}'"## + r##"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') OR ("test_table"."metadata") @> E'{\"id4\":1}'"## ); let sql = construct_filter_builder_with_json(json!({ "metadata": {"$or": [ @@ -443,7 +447,7 @@ mod tests { .to_valid_sql_query(); assert_eq!( sql, - r##"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> 
E'{\"metadata\":{\"uuid2\":\"2\"}}'"## + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR ("test_table"."metadata") @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"## ); } } diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index 6465c408d..06e158be2 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -1,5 +1,5 @@ use neon::prelude::*; -use rust_bridge::javascript::{IntoJsResult, FromJsType}; +use rust_bridge::javascript::{FromJsType, IntoJsResult}; use crate::{ pipeline::PipelineSyncData, @@ -54,7 +54,7 @@ impl IntoJsResult for Json { } Ok(js_object.upcast()) } - _ => panic!("Unsupported type for JSON conversion"), + serde_json::Value::Null => Ok(cx.null().upcast()), } } } @@ -113,6 +113,8 @@ impl FromJsType for Json { json.insert(key, json_value.0); } Ok(Self(serde_json::Value::Object(json))) + } else if arg.is_a::(cx) { + Ok(Self(serde_json::Value::Null)) } else { panic!("Unsupported type for Json conversion"); } diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index 728c2a0ce..3d81c9377 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -40,7 +40,7 @@ impl ToPyObject for Json { } dict.to_object(py) } - _ => panic!("Unsupported type for JSON conversion"), + serde_json::Value::Null => py.None(), } } } @@ -100,6 +100,9 @@ impl FromPyObject<'_> for Json { } Ok(Self(serde_json::Value::Array(json_values))) } else { + if ob.is_none() { + return Ok(Self(serde_json::Value::Null)); + } panic!("Unsupported type for JSON conversion"); } } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 96ee99b6b..8c6c355ec 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -145,7 +145,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { fn js_init_logger( mut cx: 
neon::context::FunctionContext, ) -> neon::result::JsResult { - use rust_bridge::javascript::{IntoJsResult, FromJsType}; + use rust_bridge::javascript::{FromJsType, IntoJsResult}; let level = cx.argument_opt(0); let level = >::from_option_js_type(&mut cx, level)?; let format = cx.argument_opt(1); @@ -170,6 +170,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { mod tests { use super::*; use crate::{model::Model, pipeline::Pipeline, splitter::Splitter, types::Json}; + use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec { let mut documents = Vec::new(); @@ -188,6 +189,10 @@ mod tests { documents } + /////////////////////////////// + // Collection & Pipelines ///// + /////////////////////////////// + #[sqlx::test] async fn can_create_collection() -> anyhow::Result<()> { init_logger(None, None).ok(); @@ -310,7 +315,7 @@ mod tests { collection.add_pipeline(&mut pipeline1).await?; collection.add_pipeline(&mut pipeline2).await?; collection - .upsert_documents(generate_dummy_documents(3), Some(true)) + .upsert_documents(generate_dummy_documents(3)) .await?; let status_1 = pipeline1.get_status().await?; let status_2 = pipeline2.get_status().await?; @@ -326,6 +331,10 @@ mod tests { Ok(()) } + /////////////////////////////// + // Various Searches /////////// + /////////////////////////////// + #[sqlx::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { init_logger(None, None).ok(); @@ -351,7 +360,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection .vector_search("Here is some query", &mut pipeline, None, None) @@ -390,7 +399,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = 
Pipeline::new("test_r_p_cvswre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection .vector_search("Here is some query", &mut pipeline, None, None) @@ -425,7 +434,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection .query() @@ -466,7 +475,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection .query() @@ -477,4 +486,438 @@ mod tests { collection.archive().await?; Ok(()) } + + #[sqlx::test] + async fn can_filter_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let model = Model::new(None, None, None); + let splitter = Splitter::new(None, None); + let mut pipeline = Pipeline::new( + "test_r_p_cfd_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + let mut collection = Collection::new("test_r_c_cfd_2", None); + collection.add_pipeline(&mut pipeline).await?; + collection + .upsert_documents(generate_dummy_documents(5)) + .await?; + + let filters = vec![ + (5, json!({}).into()), + ( + 3, + json!({ + "metadata": { + "id": { + "$lt": 3 + } + } + }) + .into(), + ), + ( + 1, + json!({ + "full_text_search": { + "configuration": "english", + "text": "1", + } + }) + .into(), + ), + ]; + + for (expected_result_count, filter) in filters { + let results = collection + .query() + .vector_recall("Here is some query", &mut pipeline, None) + 
.filter(filter) + .fetch_all() + .await?; + println!("{:?}", results); + assert_eq!(results.len(), expected_result_count); + } + + collection.archive().await?; + Ok(()) + } + + /////////////////////////////// + // Working With Documents ///// + /////////////////////////////// + + #[sqlx::test] + async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cuafgd_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + + let mut collection = Collection::new("test_r_c_cuagd_2", None); + collection.add_pipeline(&mut pipeline).await?; + + // Test basic upsert + let documents = vec![ + serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), + serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), + serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), + ]; + collection.upsert_documents(documents.clone()).await?; + let document = &collection.get_documents(None).await?[0]; + assert_eq!(document["document"]["text"], "hello world 1"); + + // Test upsert of text and metadata + let documents = vec![ + serde_json::json!({"id": 1, "text": "hello world new"}).into(), + serde_json::json!({"id": 2, "random_key": 12}).into(), + serde_json::json!({"id": 3, "random_key": 13}).into(), + ]; + collection.upsert_documents(documents.clone()).await?; + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "metadata": { + "random_key": { + "$eq": 12 + } + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["text"], "hello world 2"); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "metadata": { + "random_key": { + "$gte": 13 + } + } + } + }) + 
.into(), + )) + .await?; + assert_eq!(documents[0]["document"]["text"], "hello world 3"); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "full_text_search": { + "configuration": "english", + "text": "new" + } + } + }) + .into(), + )) + .await?; + assert_eq!(documents[0]["document"]["text"], "hello world new"); + assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); + + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_paginate_get_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cpgd_2", None); + collection + .upsert_documents(generate_dummy_documents(10)) + .await?; + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 5, + "offset": 0 + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![1, 2, 3, 4, 5] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 2, + "offset": 5 + }) + .into(), + )) + .await?; + let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![6, 7] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 2, + "last_row_id": last_row_id + }) + .into(), + )) + .await?; + let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![8, 9] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "limit": 1, + "last_row_id": last_row_id + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["row_id"].as_i64().unwrap()) + .collect::>(), + vec![10] + ); + + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async 
fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cfapgd_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + + let mut collection = Collection::new("test_r_c_cfapgd_1", None); + collection.add_pipeline(&mut pipeline).await?; + + collection + .upsert_documents(generate_dummy_documents(10)) + .await?; + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "metadata": { + "id": { + "$gte": 2 + } + } + }, + "limit": 2, + "offset": 0 + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["document"]["id"].as_i64().unwrap()) + .collect::>(), + vec![2, 3] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "metadata": { + "id": { + "$lte": 5 + } + } + }, + "limit": 100, + "offset": 4 + }) + .into(), + )) + .await?; + let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); + assert_eq!( + documents + .into_iter() + .map(|d| d["document"]["id"].as_i64().unwrap()) + .collect::>(), + vec![4, 5] + ); + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "filter": { + "full_text_search": { + "configuration": "english", + "text": "document" + } + }, + "limit": 100, + "last_row_id": last_row_id + }) + .into(), + )) + .await?; + assert_eq!( + documents + .into_iter() + .map(|d| d["document"]["id"].as_i64().unwrap()) + .collect::>(), + vec![6, 7, 8, 9] + ); + + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_filter_and_delete_documents() -> anyhow::Result<()> { + init_logger(None, None).ok(); + let model = Model::new(None, None, None); + let splitter = Splitter::new(None, None); + let mut pipeline = Pipeline::new( + 
"test_r_p_cfadd_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + + let mut collection = Collection::new("test_r_c_cfadd_1", None); + collection.add_pipeline(&mut pipeline).await?; + collection + .upsert_documents(generate_dummy_documents(10)) + .await?; + + collection + .delete_documents( + serde_json::json!({ + "metadata": { + "id": { + "$lt": 2 + } + } + }) + .into(), + ) + .await?; + let documents = collection.get_documents(None).await?; + assert_eq!(documents.len(), 8); + assert!(documents + .iter() + .all(|d| d["document"]["id"].as_i64().unwrap() >= 2)); + + collection + .delete_documents( + serde_json::json!({ + "full_text_search": { + "configuration": "english", + "text": "2" + } + }) + .into(), + ) + .await?; + let documents = collection.get_documents(None).await?; + assert_eq!(documents.len(), 7); + assert!(documents + .iter() + .all(|d| d["document"]["id"].as_i64().unwrap() > 2)); + + collection + .delete_documents( + serde_json::json!({ + "metadata": { + "id": { + "$gte": 6 + } + }, + "full_text_search": { + "configuration": "english", + "text": "6" + } + }) + .into(), + ) + .await?; + let documents = collection.get_documents(None).await?; + assert_eq!(documents.len(), 6); + assert!(documents + .iter() + .all(|d| d["document"]["id"].as_i64().unwrap() != 6)); + + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 54f3fc5a0..07b2a1c98 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -20,7 +20,7 @@ use crate::types::JsonPython; /// annoying, but with the traits implimented below is a breeze and can be done just using .into /// Our model runtimes -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ModelRuntime { Python, OpenAI, diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 
2b735f5a0..07440d4e3 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -69,6 +69,19 @@ pub struct Document { pub text: String, } +impl Document { + pub fn into_user_friendly_json(mut self) -> Json { + self.metadata["text"] = self.text.into(); + serde_json::json!({ + "row_id": self.id, + "created_at": self.created_at, + "source_uuid": self.source_uuid, + "document": self.metadata, + }) + .into() + } +} + // A collection of documents #[enum_def] #[derive(FromRow)] diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 4e7b2d709..87e632b34 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -659,7 +659,7 @@ impl Pipeline { ), embedding_length )) - .execute(&mut transaction) + .execute(&mut *transaction) .await?; transaction .execute( diff --git a/pgml-sdks/pgml/src/query_builder/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs similarity index 91% rename from pgml-sdks/pgml/src/query_builder/query_builder.rs rename to pgml-sdks/pgml/src/query_builder.rs index 9f70b49cc..a759cc7e4 100644 --- a/pgml-sdks/pgml/src/query_builder/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -1,54 +1,25 @@ use anyhow::Context; -use itertools::Itertools; use rust_bridge::{alias, alias_methods}; use sea_query::{ - query::SelectStatement, Alias, CommonTableExpression, Expr, Func, Iden, JoinType, Order, + query::SelectStatement, Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query, QueryStatementWriter, WithClause, }; use sea_query_binder::SqlxBinder; use std::borrow::Cow; use crate::{ - filter_builder, get_or_initialize_pool, models, pipeline::Pipeline, - remote_embeddings::build_remote_embeddings, types::Json, Collection, + filter_builder, get_or_initialize_pool, + model::ModelRuntime, + models, + pipeline::Pipeline, + remote_embeddings::build_remote_embeddings, + types::{IntoTableNameAndSchema, Json, SIden}, + Collection, }; #[cfg(feature = 
"python")] use crate::{pipeline::PipelinePython, types::JsonPython}; -#[derive(Clone)] -enum SIden<'a> { - Str(&'a str), - String(String), -} - -impl Iden for SIden<'_> { - fn unquoted(&self, s: &mut dyn std::fmt::Write) { - write!( - s, - "{}", - match self { - SIden::Str(s) => s, - SIden::String(s) => s.as_str(), - } - ) - .unwrap(); - } -} - -trait IntoTableNameAndSchema { - fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>); -} - -impl IntoTableNameAndSchema for String { - fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>) { - self.split('.') - .map(|s| SIden::String(s.to_string())) - .collect_tuple() - .expect("Malformed table name in IntoTableNameAndSchema") - } -} - #[derive(Clone, Debug)] struct QueryBuilderState {} @@ -88,7 +59,7 @@ impl QueryBuilder { if let Some(f) = filter.remove("metadata") { self = self.filter_metadata(f); } - if let Some(f) = filter.remove("full_text") { + if let Some(f) = filter.remove("full_text_search") { self = self.filter_full_text(f); } self @@ -131,7 +102,7 @@ impl QueryBuilder { .eq(configuration), ) .and_where(Expr::cust_with_values( - &format!( + format!( "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", configuration ), @@ -273,6 +244,11 @@ impl QueryBuilder { .as_ref() .context("Pipeline must be verified to perform vector search with remote embeddings")?; + // If the model runtime is python, the error was not caused by an unsupported runtime + if model.runtime == ModelRuntime::Python { + return Err(anyhow::anyhow!(e)); + } + let query_parameters = self.query_parameters.to_owned().unwrap_or_default(); let remote_embeddings = diff --git a/pgml-sdks/pgml/src/query_builder/mod.rs b/pgml-sdks/pgml/src/query_builder/mod.rs deleted file mode 100644 index 102e40e0b..000000000 --- a/pgml-sdks/pgml/src/query_builder/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod query_builder; -pub use query_builder::*; diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index fa390517e..d3d1ce306 100644 --- 
a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -1,4 +1,7 @@ +use anyhow::Context; +use itertools::Itertools; use rust_bridge::alias_manual; +use sea_query::Iden; use serde::Serialize; use std::ops::{Deref, DerefMut}; @@ -39,6 +42,27 @@ impl Serialize for Json { } } +pub(crate) trait TryToNumeric { + fn try_to_u64(&self) -> anyhow::Result; +} + +impl TryToNumeric for serde_json::Value { + fn try_to_u64(&self) -> anyhow::Result { + match self { + serde_json::Value::Number(n) => { + if n.is_f64() { + Ok(n.as_f64().unwrap() as u64) + } else if n.is_i64() { + Ok(n.as_i64().unwrap() as u64) + } else { + n.as_u64().context("limit must be an integer") + } + } + _ => Err(anyhow::anyhow!("Json value is not a number")), + } + } +} + /// A wrapper around sqlx::types::chrono::DateTime #[derive(sqlx::Type, Debug, Clone)] #[sqlx(transparent)] @@ -50,3 +74,36 @@ impl Serialize for DateTime { self.0.timestamp().serialize(serializer) } } + +#[derive(Clone)] +pub(crate) enum SIden<'a> { + Str(&'a str), + String(String), +} + +impl Iden for SIden<'_> { + fn unquoted(&self, s: &mut dyn std::fmt::Write) { + write!( + s, + "{}", + match self { + SIden::Str(s) => s, + SIden::String(s) => s.as_str(), + } + ) + .unwrap(); + } +} + +pub(crate) trait IntoTableNameAndSchema { + fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>); +} + +impl IntoTableNameAndSchema for String { + fn to_table_tuple<'b>(&self) -> (SIden<'b>, SIden<'b>) { + self.split('.') + .map(|s| SIden::String(s.to_string())) + .collect_tuple() + .expect("Malformed table name in IntoTableNameAndSchema") + } +}