From 3f615d0644924c4c6da25ec08b0df87b9196afc8 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 3 Jan 2024 12:26:20 -0800 Subject: [PATCH 1/4] update deps --- pgml-extension/Cargo.lock | 176 +++++++++++++++++++------------------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index acf5e52f2..a471c019e 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -34,9 +34,9 @@ checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anyhow" -version = "1.0.77" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9d19de80eff169429ac1e9f48fffb163916b448a44e8e046186232046d9e1f9" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" [[package]] name = "approx" @@ -120,13 +120,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.75" +version = "0.1.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" +checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -210,11 +210,11 @@ dependencies = [ "peeking_take_while", "prettyplease", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "regex", "rustc-hash", "shlex", - "syn 2.0.43", + "syn 2.0.46", "which", ] @@ -358,9 +358,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clang-sys" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" dependencies = [ "glob", "libc", @@ -405,8 +405,8 @@ checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -564,7 +564,7 @@ dependencies = [ "fnv", "ident_case", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "strsim", "syn 1.0.109", ] @@ -576,15 +576,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", ] [[package]] name = "deranged" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", ] @@ -627,7 +627,7 @@ checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" dependencies = [ "darling", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", ] @@ -742,8 +742,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -754,9 +754,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "erased-serde" -version = "0.4.1" +version = "0.4.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4adbf0983fe06bd3a5c19c8477a637c2389feb0994eca7a59e3b961054aa7c0a" +checksum = "55d05712b2d8d88102bc9868020c9e5c7a1f5527c452b9b97450a1d006140ba7" dependencies = [ "serde", ] @@ -886,8 +886,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -1149,12 +1149,12 @@ checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libloading" -version = "0.7.4" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" dependencies = [ "cfg-if", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -1609,8 +1609,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -1790,7 +1790,7 @@ checksum = "a18ac8628b7de2f29a93d0abdbdcaee95a0e0ef4b59fd4de99cc117e166e843b" dependencies = [ "pgrx-sql-entity-graph", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", ] @@ -1828,7 +1828,7 @@ dependencies = [ "pgrx-pg-config", "pgrx-sql-entity-graph", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "serde", "shlex", "sptr", @@ -1846,7 +1846,7 @@ dependencies = [ "eyre", "petgraph", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", "unescape", ] @@ -1968,19 +1968,19 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" +checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.43", + "syn 2.0.46", ] [[package]] name = "proc-macro2" -version = "1.0.71" +version = "1.0.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +checksum = "2de98502f212cfcea8d0bb305bd0f49d7ebdd75b64ba0a68f937d888f4e0d6db" dependencies = [ "unicode-ident", ] @@ -2007,9 +2007,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" dependencies = [ "cfg-if", "indoc", @@ -2024,9 +2024,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" dependencies = [ "once_cell", "target-lexicon", @@ -2034,9 +2034,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" dependencies = [ "libc", "pyo3-build-config", @@ -2044,26 +2044,26 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" dependencies = [ "proc-macro2", "pyo3-macros-backend", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] name = "pyo3-macros-backend" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" dependencies = [ "heck", "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -2080,9 +2080,9 @@ checksum = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -2279,7 +2279,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.20", + "semver 1.0.21", ] [[package]] @@ -2404,9 +2404,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" +checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" [[package]] name = "semver-parser" @@ -2425,9 +2425,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.194" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "0b114498256798c94a0689e1a15fec6005dee8ac1f41de56404b67afc2a4b773" dependencies = [ "serde_derive", ] @@ -2444,20 +2444,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.193" +version = "1.0.194" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "a3385e45322e8f9931410f01b3031ec534c3947d0e94c18049af4d9f9907d4e0" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] name = "serde_json" -version = "1.0.108" +version = "1.0.110" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +checksum = "6fbd975230bada99c8bb618e0c365c2eefa219158d5c6c29610fd09ff1833257" dependencies = [ "indexmap 2.1.0", "itoa", @@ -2659,18 +2659,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote 1.0.33", + 
"quote 1.0.35", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.43" +version = "2.0.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" +checksum = "89456b690ff72fddcecf231caedbe615c59480c93358a93dfae7fc29e3ebbf0e" dependencies = [ "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "unicode-ident", ] @@ -2723,9 +2723,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.12" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "tempfile" @@ -2753,22 +2753,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -2943,9 +2943,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "typetag" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196976efd4a62737b3a2b662cda76efb448d099b1049613d7a5d72743c611ce0" +checksum = "c43148481c7b66502c48f35b8eef38b6ccdc7a9f04bd4cc294226d901ccc9bc7" dependencies = [ "erased-serde", "inventory", @@ -2956,13 +2956,13 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eea6765137e2414c44c7b1e07c73965a118a72c46148e1e168b3fc9d3ccf3aa" +checksum = "291db8a81af4840c10d636e047cac67664e343be44e24dfdbd1492df9a5d3390" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -3125,8 +3125,8 @@ dependencies = [ "log", "once_cell", "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", "wasm-bindgen-shared", ] @@ -3136,7 +3136,7 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ - "quote 1.0.33", + "quote 1.0.35", "wasm-bindgen-macro-support", ] @@ -3147,8 +3147,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3356,9 +3356,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.31" +version = "0.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a4882e6b134d6c28953a387571f1acdd3496830d5e36c5e3a1075580ea641c" +checksum = 
"8434aeec7b290e8da5c3f0d628cb0eac6cabcb31d14bb74f779a08109a5914d6" dependencies = [ "memchr", ] From e9d0187d012c1fb72274eeba203843d2ccd28dca Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 3 Jan 2024 12:36:13 -0800 Subject: [PATCH 2/4] update more deps --- pgml-extension/Cargo.lock | 10 +++++----- pgml-extension/Cargo.toml | 4 ++-- pgml-extension/rustfmt.toml | 1 + 3 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 pgml-extension/rustfmt.toml diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index a471c019e..fbbb90e9d 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1053,7 +1053,6 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", - "serde", ] [[package]] @@ -1064,6 +1063,7 @@ checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", "hashbrown 0.14.3", + "serde", ] [[package]] @@ -1098,9 +1098,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" dependencies = [ "either", ] @@ -1731,8 +1731,8 @@ dependencies = [ "csv", "flate2", "heapless", - "indexmap 1.9.3", - "itertools 0.11.0", + "indexmap 2.1.0", + "itertools 0.12.0", "lightgbm", "linfa", "linfa-linear", diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index ab5fb00dc..362bb017b 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -24,8 +24,8 @@ csv = "1.2" flate2 = "1.0" blas = { version = "0.22" } blas-src = { version = "0.9", features = ["openblas"] } -indexmap = { version = "1.0", features = ["serde"] } -itertools = "0.11" +indexmap = { version = "2.1", features = ["serde"] } +itertools = "0.12" heapless = "0.7" lightgbm = { git = "https://github.com/postgresml/lightgbm-rs", branch = "main" } linfa = { path = "deps/linfa" } diff --git a/pgml-extension/rustfmt.toml b/pgml-extension/rustfmt.toml new file mode 100644 index 000000000..3ccd3c986 --- /dev/null +++ b/pgml-extension/rustfmt.toml @@ -0,0 +1 @@ +max_width=120 \ No newline at end of file From 80c5ba3f899a0e2e5cdd3ff544e3fbe582b4142a Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 3 Jan 2024 12:36:55 -0800 Subject: [PATCH 3/4] fmt hook --- pgml-extension/src/api.rs | 219 ++++++---------- pgml-extension/src/bindings/langchain/mod.rs | 5 +- pgml-extension/src/bindings/lightgbm.rs | 34 +-- pgml-extension/src/bindings/linfa.rs | 98 ++------ pgml-extension/src/bindings/mod.rs | 32 +-- pgml-extension/src/bindings/python/mod.rs | 11 +- pgml-extension/src/bindings/sklearn/mod.rs | 130 +++------- .../src/bindings/transformers/mod.rs | 79 ++---- .../src/bindings/transformers/transform.rs | 24 +- .../src/bindings/transformers/whitelist.rs | 22 +- pgml-extension/src/bindings/xgboost.rs | 39 +-- pgml-extension/src/lib.rs | 4 +- pgml-extension/src/metrics.rs | 36 +-- pgml-extension/src/orm/algorithm.rs | 4 +- pgml-extension/src/orm/dataset.rs | 19 +- pgml-extension/src/orm/model.rs | 237 ++++++------------ pgml-extension/src/orm/project.rs | 60 ++--- pgml-extension/src/orm/snapshot.rs | 130 +++++----- pgml-extension/src/vectors.rs | 222 +++++----------- 19 files changed, 438 insertions(+), 967 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 
5a4d8a29a..380bfb330 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -163,21 +163,30 @@ fn train_joint( let task = task.map(|t| Task::from_str(t).unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, - None => Project::create(project_name, match task { - Some(task) => task, - None => error!("Project `{}` does not exist. To create a new project, you must specify a `task`.", project_name), - }), + None => Project::create( + project_name, + match task { + Some(task) => task, + None => error!( + "Project `{}` does not exist. To create a new project, you must specify a `task`.", + project_name + ), + }, + ), }; if task.is_some() && task.unwrap() != project.task { - error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task); + error!( + "Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", + project.name, project.task + ); } let mut snapshot = match relation_name { None => { - let snapshot = project - .last_snapshot() - .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."); + let snapshot = project.last_snapshot().expect( + "You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model.", + ); info!("Using existing snapshot from {}", snapshot.snapshot_name(),); @@ -302,7 +311,7 @@ fn train_joint( #[pg_extern(name = "deploy")] fn deploy_model( - model_id: i64 + model_id: i64, ) -> TableIterator< 'static, ( @@ -319,8 +328,7 @@ fn deploy_model( ) .unwrap(); - let project_id = - project_id.unwrap_or_else(|| error!("Project does not exist.")); + let project_id = project_id.unwrap_or_else(|| error!("Project does not exist.")); let project = Project::find(project_id).unwrap(); project.deploy(model_id, Strategy::specific); @@ -351,8 +359,7 @@ fn deploy_strategy( ) .unwrap(); - let project_id = - project_id.unwrap_or_else(|| error!("Project named `{}` does not exist.", project_name)); + let project_id = project_id.unwrap_or_else(|| error!("Project named `{}` does not exist.", project_name)); let task = Task::from_str(&task.unwrap()).unwrap(); @@ -367,11 +374,7 @@ fn deploy_strategy( } match strategy { Strategy::best_score => { - let _ = write!( - sql, - "{predicate}\n{}", - task.default_target_metric_sql_order() - ); + let _ = write!(sql, "{predicate}\n{}", task.default_target_metric_sql_order()); } Strategy::most_recent => { @@ -401,22 +404,16 @@ fn deploy_strategy( _ => error!("invalid strategy"), } sql += "\nLIMIT 1"; - let (model_id, algorithm) = Spi::get_two_with_args::( - &sql, - vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ) - .unwrap(); + let (model_id, algorithm) = + Spi::get_two_with_args::(&sql, vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())]) + .unwrap(); let model_id = model_id.expect("No qualified models exist for this deployment."); let algorithm = algorithm.expect("No qualified models exist for this deployment."); let project = Project::find(project_id).unwrap(); project.deploy(model_id, strategy); - TableIterator::new(vec![( - project_name.to_string(), - strategy.to_string(), - algorithm, - )]) + TableIterator::new(vec![(project_name.to_string(), strategy.to_string(), algorithm)]) } #[pg_extern(immutable, parallel_safe, strict, name = "predict")] @@ -446,10 +443,7 @@ fn predict_i64(project_name: &str, features: Vec) -> f32 { #[pg_extern(immutable, parallel_safe, strict, name = "predict")] 
fn predict_bool(project_name: &str, features: Vec<bool>) -> f32 { - predict_f32( - project_name, - features.iter().map(|&i| i as u8 as f32).collect(), - ) + predict_f32(project_name, features.iter().map(|&i| i as u8 as f32).collect()) } #[pg_extern(immutable, parallel_safe, strict, name = "predict_proba")] @@ -507,8 +501,7 @@ fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 { let features_width = snapshot.features_width(); let mut processed = vec![0_f32; features_width]; - let feature_data = - ndarray::ArrayView2::from_shape((1, features_width), &numeric_encoded_features).unwrap(); + let feature_data = ndarray::ArrayView2::from_shape((1, features_width), &numeric_encoded_features).unwrap(); Zip::from(feature_data.columns()) .and(&snapshot.feature_positions) @@ -555,12 +548,10 @@ fn load_dataset( "linnerud" => dataset::load_linnerud(limit), "wine" => dataset::load_wine(limit), _ => { - let rows = - match crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0) - { - Ok(rows) => rows, - Err(e) => error!("{e}"), - }; + let rows = match crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0) { + Ok(rows) => rows, + Err(e) => error!("{e}"), + }; (source.into(), rows as i64) } }; @@ -579,11 +570,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> #[cfg(all(feature = "python", not(feature = "use_as_lib")))] #[pg_extern(immutable, parallel_safe, name = "embed")] -pub fn embed_batch( - transformer: &str, - inputs: Vec<&str>, - kwargs: default!(JsonB, "'{}'"), -) -> Vec<Vec<f32>> { +pub fn embed_batch(transformer: &str, inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec<Vec<f32>> { match crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) { Ok(output) => output, Err(e) => error!("{e}"), } @@ -673,13 +660,8 @@ pub fn transform_conversational_json( inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"), cache: default!(bool, false), ) -> JsonB { - if !task.0["task"] - .as_str() - .is_some_and(|v| v == "conversational") - { - error!( - "ARRAY[]::JSONB inputs for transform should only be used with a conversational task" - ); + if !task.0["task"].as_str().is_some_and(|v| v == "conversational") { + error!("ARRAY[]::JSONB inputs for transform should only be used with a conversational task"); } match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { Ok(output) => JsonB(output), @@ -697,9 +679,7 @@ pub fn transform_conversational_string( cache: default!(bool, false), ) -> JsonB { if task != "conversational" { - error!( - "ARRAY[]::JSONB inputs for transform should only be used with a conversational task" - ); + error!("ARRAY[]::JSONB inputs for transform should only be used with a conversational task"); } let task_json = json!({ "task": task }); match crate::bindings::transformers::transform(&task_json, &args.0, inputs) { Ok(output) => JsonB(output), @@ -718,10 +698,9 @@ pub fn transform_stream_json( cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { // We can unwrap this because if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -736,10 +715,9 @@ pub fn transform_stream_string( ) -> SetOfIterator<'static, JsonB> { let task_json = json!({ "task": task }); // We can
unwrap this because if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -752,19 +730,13 @@ pub fn transform_stream_conversational_json( inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"), cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { - if !task.0["task"] - .as_str() - .is_some_and(|v| v == "conversational") - { - error!( - "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task" - ); + if !task.0["task"].as_str().is_some_and(|v| v == "conversational") { + error!("ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task"); } // We can unwrap this because if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -778,16 +750,13 @@ pub fn transform_stream_conversational_string( cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { if task != "conversational" { - error!( - "ARRAY::JSONB inputs for transform_stream should only be used with a conversational task" - ); + error!("ARRAY::JSONB inputs for transform_stream should only be used with a conversational task"); } let task_json = json!({ "task": task }); // We can unwrap this because if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -802,16 +771,8 @@ fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) - #[cfg(feature = "python")] #[pg_extern(immutable, parallel_safe, name = "generate")] -fn generate_batch( - project_name: &str, - inputs: Vec<&str>, - config: default!(JsonB, "'{}'"), -) -> Vec<String> { +fn generate_batch(project_name: &str, inputs: Vec<&str>, config: default!(JsonB, "'{}'")) -> Vec<String> { - match crate::bindings::transformers::generate( - Project::get_deployed_model_id(project_name), - inputs, - config, - ) { + match crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs, config) { Ok(output) => output, Err(e) => error!("{e}"), } } @@ -857,14 +818,17 @@ fn tune( }; if task.is_some() && task.unwrap() != project.task { - error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task); + error!( + "Project `{:?}` already exists with a different task: `{:?}`. 
Create a new project instead.", + project.name, project.task + ); } let mut snapshot = match relation_name { None => { - let snapshot = project - .last_snapshot() - .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."); + let snapshot = project.last_snapshot().expect( + "You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model.", + ); info!("Using existing snapshot from {}", snapshot.snapshot_name(),); @@ -980,20 +944,13 @@ pub fn sklearn_r2_score(ground_truth: Vec, y_hat: Vec) -> f32 { #[cfg(feature = "python")] #[pg_extern(name = "sklearn_regression_metrics")] pub fn sklearn_regression_metrics(ground_truth: Vec, y_hat: Vec) -> JsonB { - let metrics = unwrap_or_error!(crate::bindings::sklearn::regression_metrics( - &ground_truth, - &y_hat, - )); + let metrics = unwrap_or_error!(crate::bindings::sklearn::regression_metrics(&ground_truth, &y_hat,)); JsonB(json!(metrics)) } #[cfg(feature = "python")] #[pg_extern(name = "sklearn_classification_metrics")] -pub fn sklearn_classification_metrics( - ground_truth: Vec, - y_hat: Vec, - num_classes: i64, -) -> JsonB { +pub fn sklearn_classification_metrics(ground_truth: Vec, y_hat: Vec, num_classes: i64) -> JsonB { let metrics = unwrap_or_error!(crate::bindings::sklearn::classification_metrics( &ground_truth, &y_hat, @@ -1006,32 +963,16 @@ pub fn sklearn_classification_metrics( #[pg_extern] pub fn dump_all(path: &str) { let p = std::path::Path::new(path).join("projects.csv"); - Spi::run(&format!( - "COPY pgml.projects TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.projects TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); - Spi::run(&format!( - "COPY pgml.snapshots TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.snapshots TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("models.csv"); - Spi::run(&format!( - "COPY pgml.models TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.models TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("files.csv"); - Spi::run(&format!( - "COPY pgml.files TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.files TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( @@ -1044,11 +985,7 @@ pub fn dump_all(path: &str) { #[pg_extern] pub fn load_all(path: &str) { let p = std::path::Path::new(path).join("projects.csv"); - Spi::run(&format!( - "COPY pgml.projects FROM '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.projects FROM '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); Spi::run(&format!( @@ -1058,18 +995,10 @@ pub fn load_all(path: &str) { .unwrap(); let p = std::path::Path::new(path).join("models.csv"); - Spi::run(&format!( - "COPY pgml.models FROM '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.models FROM '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("files.csv"); - Spi::run(&format!( - "COPY pgml.files FROM '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.files FROM '{}' CSV HEADER", 
p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( @@ -1630,9 +1559,7 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. - let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'") - .unwrap(); + let setting = Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -1670,9 +1597,7 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. - let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'") - .unwrap(); + let setting = Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -1710,9 +1635,7 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. - let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'") - .unwrap(); + let setting = Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); info!("Data directory: {}", setting.unwrap()); diff --git a/pgml-extension/src/bindings/langchain/mod.rs b/pgml-extension/src/bindings/langchain/mod.rs index 7d8d2582f..d17993df7 100644 --- a/pgml-extension/src/bindings/langchain/mod.rs +++ b/pgml-extension/src/bindings/langchain/mod.rs @@ -18,10 +18,7 @@ pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result, - ) -> std::result::Result<(), std::fmt::Error> { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { formatter.debug_struct("Estimator").finish() } } @@ -28,10 +25,7 @@ pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Result Result> { +pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Result> { fit(dataset, hyperparams, Task::classification) } @@ -39,17 +33,11 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result { - hyperparams.insert( - "objective".to_string(), - serde_json::Value::from("regression"), - ); + hyperparams.insert("objective".to_string(), serde_json::Value::from("regression")); } Task::classification => { if dataset.num_distinct_labels > 2 { - hyperparams.insert( - "objective".to_string(), - serde_json::Value::from("multiclass"), - ); + hyperparams.insert("objective".to_string(), serde_json::Value::from("multiclass")); hyperparams.insert( "num_class".to_string(), serde_json::Value::from(dataset.num_distinct_labels), @@ -61,12 +49,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result error!("lightgbm only supports `regression` and `classification` tasks."), }; - let data = lightgbm::Dataset::from_vec( - &dataset.x_train, - &dataset.y_train, - dataset.num_features as i32, - ) - .unwrap(); + let data = lightgbm::Dataset::from_vec(&dataset.x_train, &dataset.y_train, dataset.num_features as i32).unwrap(); let estimator = lightgbm::Booster::train(data, &json! 
{hyperparams}).unwrap(); @@ -75,12 +58,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result Result> { + fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result<Vec<f32>> { let results = self.predict_proba(features, num_features)?; Ok(match num_classes { // TODO make lightgbm predict both classes like scikit and xgboost diff --git a/pgml-extension/src/bindings/linfa.rs b/pgml-extension/src/bindings/linfa.rs index d0dbeda47..c2a6fc437 100644 --- a/pgml-extension/src/bindings/linfa.rs +++ b/pgml-extension/src/bindings/linfa.rs @@ -20,11 +20,7 @@ impl LinearRegression { where Self: Sized, { - let records = ArrayView2::from_shape( - (dataset.num_train_rows, dataset.num_features), - &dataset.x_train, - ) - .unwrap(); + let records = ArrayView2::from_shape((dataset.num_train_rows, dataset.num_features), &dataset.x_train).unwrap(); let targets = ArrayView1::from_shape(dataset.num_train_rows, &dataset.y_train).unwrap(); @@ -34,8 +30,7 @@ impl LinearRegression { for (key, value) in hyperparams { match key.as_str() { "fit_intercept" => { - estimator = estimator - .with_intercept(value.as_bool().expect("fit_intercept must be boolean")) + estimator = estimator.with_intercept(value.as_bool().expect("fit_intercept must be boolean")) } _ => bail!("Unknown {}: {:?}", key.as_str(), value), }; @@ -52,14 +47,8 @@ impl LinearRegression { impl Bindings for LinearRegression { /// Predict a novel datapoint. - fn predict( - &self, - features: &[f32], - num_features: usize, - _num_classes: usize, - ) -> Result<Vec<f32>> { - let records = - ArrayView2::from_shape((features.len() / num_features, num_features), features)?; + fn predict(&self, features: &[f32], num_features: usize, _num_classes: usize) -> Result<Vec<f32>> { + let records = ArrayView2::from_shape((features.len() / num_features, num_features), features)?; Ok(self.estimator.predict(records).targets.into_raw_vec()) } @@ -96,11 +85,7 @@ impl LogisticRegression { where Self: Sized, { - let records = ArrayView2::from_shape( - (dataset.num_train_rows, dataset.num_features), - &dataset.x_train, - ) - .unwrap(); + let records = ArrayView2::from_shape((dataset.num_train_rows, dataset.num_features), &dataset.x_train).unwrap(); // Copy to convert to i32 because LogisticRegression doesn't support continuous targets.
let y_train: Vec = dataset.y_train.iter().map(|x| *x as i32).collect(); @@ -114,22 +99,16 @@ impl LogisticRegression { for (key, value) in hyperparams { match key.as_str() { "fit_intercept" => { - estimator = estimator - .with_intercept(value.as_bool().expect("fit_intercept must be boolean")) - } - "alpha" => { - estimator = - estimator.alpha(value.as_f64().expect("alpha must be a float") as f32) + estimator = estimator.with_intercept(value.as_bool().expect("fit_intercept must be boolean")) } + "alpha" => estimator = estimator.alpha(value.as_f64().expect("alpha must be a float") as f32), "max_iterations" => { - estimator = estimator.max_iterations( - value.as_i64().expect("max_iterations must be an integer") as u64, - ) + estimator = + estimator.max_iterations(value.as_i64().expect("max_iterations must be an integer") as u64) } "gradient_tolerance" => { - estimator = estimator.gradient_tolerance( - value.as_f64().expect("gradient_tolerance must be a float") as f32, - ) + estimator = estimator + .gradient_tolerance(value.as_f64().expect("gradient_tolerance must be a float") as f32) } _ => bail!("Unknown {}: {:?}", key.as_str(), value), }; @@ -149,22 +128,16 @@ impl LogisticRegression { for (key, value) in hyperparams { match key.as_str() { "fit_intercept" => { - estimator = estimator - .with_intercept(value.as_bool().expect("fit_intercept must be boolean")) - } - "alpha" => { - estimator = - estimator.alpha(value.as_f64().expect("alpha must be a float") as f32) + estimator = estimator.with_intercept(value.as_bool().expect("fit_intercept must be boolean")) } + "alpha" => estimator = estimator.alpha(value.as_f64().expect("alpha must be a float") as f32), "max_iterations" => { - estimator = estimator.max_iterations( - value.as_i64().expect("max_iterations must be an integer") as u64, - ) + estimator = + estimator.max_iterations(value.as_i64().expect("max_iterations must be an integer") as u64) } "gradient_tolerance" => { - estimator = estimator.gradient_tolerance( - value.as_f64().expect("gradient_tolerance must be a float") as f32, - ) + estimator = estimator + .gradient_tolerance(value.as_f64().expect("gradient_tolerance must be a float") as f32) } _ => bail!("Unknown {}: {:?}", key.as_str(), value), }; @@ -187,16 +160,8 @@ impl Bindings for LogisticRegression { bail!("predict_proba is currently only supported by the Python runtime.") } - fn predict( - &self, - features: &[f32], - _num_features: usize, - _num_classes: usize, - ) -> Result> { - let records = ArrayView2::from_shape( - (features.len() / self.num_features, self.num_features), - features, - )?; + fn predict(&self, features: &[f32], _num_features: usize, _num_classes: usize) -> Result> { + let records = ArrayView2::from_shape((features.len() / self.num_features, self.num_features), features)?; Ok(if self.num_distinct_labels > 2 { self.estimator_multi @@ -244,11 +209,7 @@ pub struct Svm { impl Svm { pub fn fit(dataset: &Dataset, hyperparams: &Hyperparams) -> Result> { - let records = ArrayView2::from_shape( - (dataset.num_train_rows, dataset.num_features), - &dataset.x_train, - ) - .unwrap(); + let records = ArrayView2::from_shape((dataset.num_train_rows, dataset.num_features), &dataset.x_train).unwrap(); let targets = ArrayView1::from_shape(dataset.num_train_rows, &dataset.y_train).unwrap(); @@ -264,13 +225,8 @@ impl Svm { for (key, value) in hyperparams { match key.as_str() { - "eps" => { - estimator = estimator.eps(value.as_f64().expect("eps must be a float") as f32) - } - "shrinking" => { - estimator = - 
estimator.shrinking(value.as_bool().expect("shrinking must be a bool")) - } + "eps" => estimator = estimator.eps(value.as_f64().expect("eps must be a float") as f32), + "shrinking" => estimator = estimator.shrinking(value.as_bool().expect("shrinking must be a bool")), "kernel" => { match value.as_str().expect("kernel must be a string") { "poli" => estimator = estimator.polynomial_kernel(3.0, 1.0), // degree = 3, c = 1.0 as per Scikit @@ -298,14 +254,8 @@ impl Bindings for Svm { } /// Predict a novel datapoint. - fn predict( - &self, - features: &[f32], - num_features: usize, - _num_classes: usize, - ) -> Result> { - let records = - ArrayView2::from_shape((features.len() / num_features, num_features), features)?; + fn predict(&self, features: &[f32], num_features: usize, _num_classes: usize) -> Result> { + let records = ArrayView2::from_shape((features.len() / num_features, num_features), features)?; Ok(self.estimator.predict(records).targets.into_raw_vec()) } diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 79e543490..d877f490a 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -11,19 +11,18 @@ use crate::orm::*; #[macro_export] macro_rules! create_pymodule { ($pyfile:literal) => { - pub static PY_MODULE: once_cell::sync::Lazy< - anyhow::Result>, - > = once_cell::sync::Lazy::new(|| { - pyo3::Python::with_gil(|py| -> anyhow::Result> { - use $crate::bindings::TracebackError; - let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile)); - Ok( - pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") - .format_traceback(py)? - .into(), - ) - }) - }); + pub static PY_MODULE: once_cell::sync::Lazy>> = + once_cell::sync::Lazy::new(|| { + pyo3::Python::with_gil(|py| -> anyhow::Result> { + use $crate::bindings::TracebackError; + let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile)); + Ok( + pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") + .format_traceback(py)? + .into(), + ) + }) + }); }; } @@ -59,12 +58,7 @@ pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result Result>; + fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result>; /// Predict the probability of each class. fn predict_proba(&self, features: &[f32], num_features: usize) -> Result>; diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index 9ab7300c0..84e7505b7 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -16,8 +16,7 @@ create_pymodule!("/src/bindings/python/python.py"); pub fn activate_venv(venv: &str) -> Result { Python::with_gil(|py| { let activate_venv: Py = get_module!(PY_MODULE).getattr(py, "activate_venv")?; - let result: Py = - activate_venv.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))?; + let result: Py = activate_venv.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))?; Ok(result.extract(py)?) }) @@ -39,9 +38,7 @@ pub fn pip_freeze() -> Result> Ok(result.extract(py)?) })?; - Ok(TableIterator::new( - packages.into_iter().map(|package| (package,)), - )) + Ok(TableIterator::new(packages.into_iter().map(|package| (package,)))) } pub fn validate_dependencies() -> Result { @@ -54,9 +51,7 @@ pub fn validate_dependencies() -> Result { match py.import(module) { Ok(_) => (), Err(e) => { - panic!( - "The {module} package is missing. 
Install it with `sudo pip3 install {module}`\n{e}" - ); + panic!("The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"); } } } diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index 4b8ce6625..bee066b87 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -33,10 +33,7 @@ wrap_fit!(elastic_net_regression, "elastic_net_regression"); wrap_fit!(ridge_regression, "ridge_regression"); wrap_fit!(random_forest_regression, "random_forest_regression"); wrap_fit!(xgboost_regression, "xgboost_regression"); -wrap_fit!( - xgboost_random_forest_regression, - "xgboost_random_forest_regression" -); +wrap_fit!(xgboost_random_forest_regression, "xgboost_random_forest_regression"); wrap_fit!( orthogonal_matching_persuit_regression, "orthogonal_matching_persuit_regression" @@ -50,10 +47,7 @@ wrap_fit!( stochastic_gradient_descent_regression, "stochastic_gradient_descent_regression" ); -wrap_fit!( - passive_aggressive_regression, - "passive_aggressive_regression" -); +wrap_fit!(passive_aggressive_regression, "passive_aggressive_regression"); wrap_fit!(ransac_regression, "ransac_regression"); wrap_fit!(theil_sen_regression, "theil_sen_regression"); wrap_fit!(huber_regression, "huber_regression"); @@ -64,14 +58,8 @@ wrap_fit!(nu_svm_regression, "nu_svm_regression"); wrap_fit!(ada_boost_regression, "ada_boost_regression"); wrap_fit!(bagging_regression, "bagging_regression"); wrap_fit!(extra_trees_regression, "extra_trees_regression"); -wrap_fit!( - gradient_boosting_trees_regression, - "gradient_boosting_trees_regression" -); -wrap_fit!( - hist_gradient_boosting_regression, - "hist_gradient_boosting_regression" -); +wrap_fit!(gradient_boosting_trees_regression, "gradient_boosting_trees_regression"); +wrap_fit!(hist_gradient_boosting_regression, "hist_gradient_boosting_regression"); wrap_fit!(least_angle_regression, "least_angle_regression"); wrap_fit!(lasso_least_angle_regression, "lasso_least_angle_regression"); wrap_fit!(linear_svm_regression, "linear_svm_regression"); @@ -91,10 +79,7 @@ wrap_fit!( "stochastic_gradient_descent_classification" ); wrap_fit!(perceptron_classification, "perceptron_classification"); -wrap_fit!( - passive_aggressive_classification, - "passive_aggressive_classification" -); +wrap_fit!(passive_aggressive_classification, "passive_aggressive_classification"); wrap_fit!(gaussian_process, "gaussian_process"); wrap_fit!(nu_svm_classification, "nu_svm_classification"); wrap_fit!(ada_boost_classification, "ada_boost_classification"); @@ -124,47 +109,41 @@ wrap_fit!(spectral, "spectral_clustering"); wrap_fit!(spectral_bi, "spectral_biclustering"); wrap_fit!(spectral_co, "spectral_coclustering"); -fn fit( - dataset: &Dataset, - hyperparams: &Hyperparams, - algorithm_task: &'static str, -) -> Result> { +fn fit(dataset: &Dataset, hyperparams: &Hyperparams, algorithm_task: &'static str) -> Result> { let hyperparams = serde_json::to_string(hyperparams).unwrap(); - let (estimator, predict, predict_proba) = - Python::with_gil(|py| -> Result<(Py, Py, Py)> { - let module = get_module!(PY_MODULE); + let (estimator, predict, predict_proba) = Python::with_gil(|py| -> Result<(Py, Py, Py)> { + let module = get_module!(PY_MODULE); - let estimator: Py = module.getattr(py, "estimator")?; + let estimator: Py = module.getattr(py, "estimator")?; - let train: Py = estimator.call1( + let train: Py = estimator.call1( + py, + PyTuple::new( py, - PyTuple::new( - py, - &[ - 
String::from(algorithm_task).into_py(py), - dataset.num_features.into_py(py), - dataset.num_labels.into_py(py), - hyperparams.into_py(py), - ], - ), - )?; - - let estimator: Py = - train.call1(py, PyTuple::new(py, [&dataset.x_train, &dataset.y_train]))?; - - let predict: Py = module - .getattr(py, "predictor")? - .call1(py, PyTuple::new(py, [&estimator]))? - .extract(py)?; + &[ + String::from(algorithm_task).into_py(py), + dataset.num_features.into_py(py), + dataset.num_labels.into_py(py), + hyperparams.into_py(py), + ], + ), + )?; + + let estimator: Py = train.call1(py, PyTuple::new(py, [&dataset.x_train, &dataset.y_train]))?; + + let predict: Py = module + .getattr(py, "predictor")? + .call1(py, PyTuple::new(py, [&estimator]))? + .extract(py)?; - let predict_proba: Py = module - .getattr(py, "predictor_proba")? - .call1(py, PyTuple::new(py, [&estimator]))? - .extract(py)?; + let predict_proba: Py = module + .getattr(py, "predictor_proba")? + .call1(py, PyTuple::new(py, [&estimator]))? + .extract(py)?; - Ok((estimator, predict, predict_proba)) - })?; + Ok((estimator, predict, predict_proba)) + })?; Ok(Box::new(Estimator { estimator, @@ -183,28 +162,15 @@ unsafe impl Send for Estimator {} unsafe impl Sync for Estimator {} impl std::fmt::Debug for Estimator { - fn fmt( - &self, - formatter: &mut std::fmt::Formatter<'_>, - ) -> std::result::Result<(), std::fmt::Error> { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { formatter.debug_struct("Estimator").finish() } } impl Bindings for Estimator { /// Predict a novel datapoint. - fn predict( - &self, - features: &[f32], - _num_features: usize, - _num_classes: usize, - ) -> Result> { - Python::with_gil(|py| { - Ok(self - .predict - .call1(py, PyTuple::new(py, [features]))? - .extract(py)?) - }) + fn predict(&self, features: &[f32], _num_features: usize, _num_classes: usize) -> Result> { + Python::with_gil(|py| Ok(self.predict.call1(py, PyTuple::new(py, [features]))?.extract(py)?)) } fn predict_proba(&self, features: &[f32], _num_features: usize) -> Result> { @@ -220,9 +186,7 @@ impl Bindings for Estimator { fn to_bytes(&self) -> Result> { Python::with_gil(|py| { let save = get_module!(PY_MODULE).getattr(py, "save")?; - Ok(save - .call1(py, PyTuple::new(py, [&self.estimator]))? - .extract(py)?) + Ok(save.call1(py, PyTuple::new(py, [&self.estimator]))?.extract(py)?) }) } @@ -258,12 +222,8 @@ impl Bindings for Estimator { fn sklearn_metric(name: &str, ground_truth: &[f32], y_hat: &[f32]) -> Result { Python::with_gil(|py| { - let calculate_metric = get_module!(PY_MODULE) - .getattr(py, "calculate_metric") - .unwrap(); - let wrapper: Py = calculate_metric - .call1(py, PyTuple::new(py, [name]))? - .extract(py)?; + let calculate_metric = get_module!(PY_MODULE).getattr(py, "calculate_metric").unwrap(); + let wrapper: Py = calculate_metric.call1(py, PyTuple::new(py, [name]))?.extract(py)?; let score: f32 = wrapper .call1(py, PyTuple::new(py, [ground_truth, y_hat]))? 
@@ -315,11 +275,7 @@ pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> Result Result> { +pub fn classification_metrics(ground_truth: &[f32], y_hat: &[f32], num_classes: usize) -> Result> { let mut scores = Python::with_gil(|py| -> Result> { let calculate_metric = get_module!(PY_MODULE).getattr(py, "classification_metrics")?; let scores: HashMap = calculate_metric @@ -337,11 +293,7 @@ pub fn classification_metrics( Ok(scores) } -pub fn cluster_metrics( - num_features: usize, - inputs: &[f32], - labels: &[f32], -) -> Result> { +pub fn cluster_metrics(num_features: usize, inputs: &[f32], labels: &[f32]) -> Result> { Python::with_gil(|py| { let calculate_metric = get_module!(PY_MODULE).getattr(py, "cluster_metrics")?; diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 9a8528ddb..b300d84e3 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -33,18 +33,12 @@ pub fn get_model_from(task: &Value) -> Result { }) } -pub fn embed( - transformer: &str, - inputs: Vec<&str>, - kwargs: &serde_json::Value, -) -> Result>> { +pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -> Result>> { crate::bindings::python::activate()?; let kwargs = serde_json::to_string(kwargs)?; Python::with_gil(|py| -> Result>> { - let embed: Py = get_module!(PY_MODULE) - .getattr(py, "embed") - .format_traceback(py)?; + let embed: Py = get_module!(PY_MODULE).getattr(py, "embed").format_traceback(py)?; let output = embed .call1( py, @@ -63,21 +57,14 @@ pub fn embed( }) } -pub fn tune( - task: &Task, - dataset: TextDataset, - hyperparams: &JsonB, - path: &Path, -) -> Result> { +pub fn tune(task: &Task, dataset: TextDataset, hyperparams: &JsonB, path: &Path) -> Result> { crate::bindings::python::activate()?; let task = task.to_string(); let hyperparams = serde_json::to_string(&hyperparams.0)?; Python::with_gil(|py| -> Result> { - let tune = get_module!(PY_MODULE) - .getattr(py, "tune") - .format_traceback(py)?; + let tune = get_module!(PY_MODULE).getattr(py, "tune").format_traceback(py)?; let path = path.to_string_lossy(); let output = tune .call1( @@ -102,9 +89,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result Result> { - let generate = get_module!(PY_MODULE) - .getattr(py, "generate") - .format_traceback(py)?; + let generate = get_module!(PY_MODULE).getattr(py, "generate").format_traceback(py)?; let config = serde_json::to_string(&config.0)?; // cloning inputs in case we have to re-call on error is rather unfortunate here // similarly, using a json string to pass kwargs is also unfortunate extra parsing @@ -130,14 +115,10 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result Result<()> { } std::fs::create_dir_all(&dir).context("failed to create directory while dumping model")?; Spi::connect(|client| -> Result<()> { - let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", - None, - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), - ]) - )?; + let result = client.select( + "SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", + None, + Some(vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())]), + )?; for row in result { let mut path = dir.clone(); path.push( row.get::(1)? .ok_or(anyhow!("row get ordinal 1 returned None"))?, ); - let data: Vec = row - .get(3)? 
- .ok_or(anyhow!("row get ordinal 3 returned None"))?; - let mut file = std::fs::OpenOptions::new() - .create(true) - .append(true) - .open(path)?; + let data: Vec = row.get(3)?.ok_or(anyhow!("row get ordinal 3 returned None"))?; + let mut file = std::fs::OpenOptions::new().create(true).append(true).open(path)?; let _num_bytes = file.write(&data)?; file.flush()?; @@ -217,9 +192,7 @@ pub fn load_dataset( // Columns are a (name: String, values: Vec) pair let json: serde_json::Value = serde_json::from_str(&dataset)?; - let json = json - .as_object() - .ok_or(anyhow!("dataset json is not object"))?; + let json = json.as_object().ok_or(anyhow!("dataset json is not object"))?; let types = json .get("types") .ok_or(anyhow!("dataset json missing `types` key"))? @@ -238,9 +211,7 @@ pub fn load_dataset( let column_types = types .iter() .map(|(name, type_)| -> Result { - let type_ = type_ - .as_str() - .ok_or(anyhow!("expected {type_} to be a json string"))?; + let type_ = type_.as_str().ok_or(anyhow!("expected {type_} to be a json string"))?; let type_ = match type_ { "string" => "TEXT", "dict" | "list" => "JSONB", @@ -276,16 +247,17 @@ pub fn load_dataset( .len(); // Avoid the existence warning by checking the schema for the table first - let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) - ])?.ok_or(anyhow!("table count query returned None"))?; + let table_count = Spi::get_one_with_args::( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", + vec![(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())], + )? + .ok_or(anyhow!("table count query returned None"))?; if table_count == 1 { Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#))?; } Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?; - let insert = - format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); + let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); for i in 0..num_rows { let mut row = Vec::with_capacity(num_cols); for (name, values) in data { @@ -307,10 +279,7 @@ pub fn load_dataset( .ok_or_else(|| anyhow!("expected {value} to be string"))? 
.into_datum(), )), - "dict" | "list" => row.push(( - PgBuiltInOids::JSONBOID.oid(), - JsonB(value.clone()).into_datum(), - )), + "dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())), "int64" | "int32" | "int16" => row.push(( PgBuiltInOids::INT8OID.oid(), value diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index fa03984d9..21503f186 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -54,17 +54,12 @@ pub fn transform( let inputs = serde_json::to_string(&inputs)?; let results = Python::with_gil(|py| -> Result { - let transform: Py = get_module!(PY_MODULE) - .getattr(py, "transform") - .format_traceback(py)?; + let transform: Py = get_module!(PY_MODULE).getattr(py, "transform").format_traceback(py)?; let output = transform .call1( py, - PyTuple::new( - py, - &[task.into_py(py), args.into_py(py), inputs.into_py(py)], - ), + PyTuple::new(py, &[task.into_py(py), args.into_py(py), inputs.into_py(py)]), ) .format_traceback(py)?; @@ -87,21 +82,14 @@ pub fn transform_stream( let input = serde_json::to_string(&input)?; Python::with_gil(|py| -> Result> { - let transform: Py = get_module!(PY_MODULE) - .getattr(py, "transform") - .format_traceback(py)?; + let transform: Py = get_module!(PY_MODULE).getattr(py, "transform").format_traceback(py)?; let output = transform .call1( py, PyTuple::new( py, - &[ - task.into_py(py), - args.into_py(py), - input.into_py(py), - true.into_py(py), - ], + &[task.into_py(py), args.into_py(py), input.into_py(py), true.into_py(py)], ), ) .format_traceback(py)?; @@ -115,8 +103,6 @@ pub fn transform_stream_iterator( args: &serde_json::Value, input: T, ) -> Result { - let python_iter = transform_stream(task, args, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = transform_stream(task, args, input).map_err(|e| error!("{e}")).unwrap(); Ok(TransformStreamIterator::new(python_iter)) } diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index 3714091d1..0194180c0 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -17,8 +17,7 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { }; let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST); - let model_is_allowed = - whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); + let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { bail!("model {task_model} is not whitelisted. 
Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf"); } @@ -45,13 +44,7 @@ fn config_csv_list(name: &str) -> Vec { Some(value) => value .trim_matches('"') .split(',') - .filter_map(|s| { - if s.is_empty() { - None - } else { - Some(s.to_string()) - } - }) + .filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) }) .collect(), None => vec![], } @@ -76,13 +69,10 @@ fn get_trust_remote_code(task: &Value) -> Option { // The JSON key for the trust remote code flag static TASK_REMOTE_CODE_KEY: &str = "trust_remote_code"; match task { - Value::Object(map) => map.get(TASK_REMOTE_CODE_KEY).and_then(|v| { - if let Value::Bool(trust) = v { - Some(*trust) - } else { - None - } - }), + Value::Object(map) => { + map.get(TASK_REMOTE_CODE_KEY) + .and_then(|v| if let Value::Bool(trust) = v { Some(*trust) } else { None }) + } _ => None, } } diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs index be3d2b09f..3e533d5f3 100644 --- a/pgml-extension/src/bindings/xgboost.rs +++ b/pgml-extension/src/bindings/xgboost.rs @@ -128,9 +128,7 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters { }, "max_leaves" => params.max_leaves(value.as_u64().unwrap() as u32), "max_bin" => params.max_bin(value.as_u64().unwrap() as u32), - "booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => { - &mut params - } // Valid but not relevant to this section + "booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => &mut params, // Valid but not relevant to this section "nthread" => &mut params, "random_state" => &mut params, _ => panic!("Unknown hyperparameter {:?}: {:?}", key, value), @@ -143,10 +141,7 @@ pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Result Result> { +pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Result> { fit( dataset, hyperparams, @@ -187,12 +182,8 @@ fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective { "gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw, "count:poisson" => learning::Objective::CountPoisson, "survival:cox" => learning::Objective::SurvivalCox, - "multi:softmax" => { - learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap()) - } - "multi:softprob" => { - learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap()) - } + "multi:softmax" => learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap()), + "multi:softprob" => learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap()), "rank:pairwise" => learning::Objective::RankPairwise, "reg:gamma" => learning::Objective::RegGamma, "reg:tweedie" => learning::Objective::RegTweedie(Some(dataset.num_distinct_labels as f32)), @@ -200,11 +191,7 @@ fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective { } } -fn fit( - dataset: &Dataset, - hyperparams: &Hyperparams, - objective: learning::Objective, -) -> Result> { +fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Objective) -> Result> { // split the train/test data into DMatrix let mut dtrain = DMatrix::from_dense(&dataset.x_train, dataset.num_train_rows).unwrap(); let mut dtest = DMatrix::from_dense(&dataset.x_test, dataset.num_test_rows).unwrap(); @@ -230,9 +217,7 @@ fn fit( .collect(), ) } else { - learning::Metrics::Custom(Vec::from([eval_metric_from_string( - metrics.as_str().unwrap(), - )])) + 
learning::Metrics::Custom(Vec::from([eval_metric_from_string(metrics.as_str().unwrap())])) } } None => learning::Metrics::Auto, @@ -314,21 +299,13 @@ unsafe impl Send for Estimator {} unsafe impl Sync for Estimator {} impl std::fmt::Debug for Estimator { - fn fmt( - &self, - formatter: &mut std::fmt::Formatter<'_>, - ) -> std::result::Result<(), std::fmt::Error> { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { formatter.debug_struct("Estimator").finish() } } impl Bindings for Estimator { - fn predict( - &self, - features: &[f32], - num_features: usize, - num_classes: usize, - ) -> Result> { + fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result> { let x = DMatrix::from_dense(features, features.len() / num_features)?; let y = self.estimator.predict(&x)?; Ok(match num_classes { diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index ce0bdbeb2..2bf5235d4 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -57,7 +57,9 @@ pub mod pg_test { let option = format!("pgml.venv = '{venv}'"); options.push(Box::leak(option.into_boxed_str())); } else { - println!("If using virtualenv for Python dependencies, set the `PGML_VENV` environment variable for testing"); + println!( + "If using virtualenv for Python dependencies, set the `PGML_VENV` environment variable for testing" + ); } options } diff --git a/pgml-extension/src/metrics.rs b/pgml-extension/src/metrics.rs index b3c1d2b5d..0d674668b 100644 --- a/pgml-extension/src/metrics.rs +++ b/pgml-extension/src/metrics.rs @@ -47,11 +47,7 @@ impl ConfusionMatrix { /// and the predictions. /// `num_classes` is passed in to ensure that all classes /// were present in the test set. - pub fn new( - ground_truth: &ArrayView1, - y_hat: &ArrayView1, - num_classes: usize, - ) -> ConfusionMatrix { + pub fn new(ground_truth: &ArrayView1, y_hat: &ArrayView1, num_classes: usize) -> ConfusionMatrix { // Distinct classes. let mut classes = ground_truth.iter().collect::>(); classes.extend(&mut y_hat.iter().collect::>().into_iter()); @@ -115,22 +111,14 @@ impl ConfusionMatrix { /// Average recall. pub fn recall(&self) -> f32 { - let recalls = self - .metrics - .iter() - .map(|m| m.tp / (m.tp + m.fn_)) - .collect::>(); + let recalls = self.metrics.iter().map(|m| m.tp / (m.tp + m.fn_)).collect::>(); recalls.iter().sum::() / recalls.len() as f32 } /// Average precision. pub fn precision(&self) -> f32 { - let precisions = self - .metrics - .iter() - .map(|m| m.tp / (m.tp + m.fp)) - .collect::>(); + let precisions = self.metrics.iter().map(|m| m.tp / (m.tp + m.fp)).collect::>(); precisions.iter().sum::() / precisions.len() as f32 } @@ -162,16 +150,8 @@ impl ConfusionMatrix { /// Calculate f1 using the average of class f1's. /// This gives equal opportunity to each class to impact the overall score. 
fn f1_macro(&self) -> f32 { - let recalls = self - .metrics - .iter() - .map(|m| m.tp / (m.tp + m.fn_)) - .collect::>(); - let precisions = self - .metrics - .iter() - .map(|m| m.tp / (m.tp + m.fp)) - .collect::>(); + let recalls = self.metrics.iter().map(|m| m.tp / (m.tp + m.fn_)).collect::>(); + let precisions = self.metrics.iter().map(|m| m.tp / (m.tp + m.fp)).collect::>(); let mut f1s = Vec::new(); @@ -194,11 +174,7 @@ mod test { let ground_truth = array![1, 2, 3, 4, 4]; let y_hat = array![1, 2, 3, 4, 4]; - let mat = ConfusionMatrix::new( - &ArrayView1::from(&ground_truth), - &ArrayView1::from(&y_hat), - 4, - ); + let mat = ConfusionMatrix::new(&ArrayView1::from(&ground_truth), &ArrayView1::from(&y_hat), 4); let f1 = mat.f1(Average::Macro); let f1_micro = mat.f1(Average::Micro); diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs index a8a72d1fb..21a87e3bf 100644 --- a/pgml-extension/src/orm/algorithm.rs +++ b/pgml-extension/src/orm/algorithm.rs @@ -122,9 +122,7 @@ impl std::string::ToString for Algorithm { Algorithm::lasso_least_angle => "lasso_least_angle".to_string(), Algorithm::orthogonal_matching_pursuit => "orthogonal_matching_pursuit".to_string(), Algorithm::bayesian_ridge => "bayesian_ridge".to_string(), - Algorithm::automatic_relevance_determination => { - "automatic_relevance_determination".to_string() - } + Algorithm::automatic_relevance_determination => "automatic_relevance_determination".to_string(), Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent".to_string(), Algorithm::perceptron => "perceptron".to_string(), Algorithm::passive_aggressive => "passive_aggressive".to_string(), diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs index 9e22ef0ae..062886a5c 100644 --- a/pgml-extension/src/orm/dataset.rs +++ b/pgml-extension/src/orm/dataset.rs @@ -94,9 +94,12 @@ impl Display for TextDataset { fn drop_table_if_exists(table_name: &str) { // Avoid the existence warning for DROP TABLE IF EXISTS by checking the schema for the table first - let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum()) - ]).unwrap().unwrap(); + let table_count = Spi::get_one_with_args::( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", + vec![(PgBuiltInOids::TEXTOID.oid(), table_name.into_datum())], + ) + .unwrap() + .unwrap(); if table_count == 1 { Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap(); } @@ -476,15 +479,9 @@ pub fn load_iris(limit: Option) -> (String, i64) { VALUES ($1, $2, $3, $4, $5) ", Some(vec![ - ( - PgBuiltInOids::FLOAT4OID.oid(), - row.sepal_length.into_datum(), - ), + (PgBuiltInOids::FLOAT4OID.oid(), row.sepal_length.into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), row.sepal_width.into_datum()), - ( - PgBuiltInOids::FLOAT4OID.oid(), - row.petal_length.into_datum(), - ), + (PgBuiltInOids::FLOAT4OID.oid(), row.petal_length.into_datum()), (PgBuiltInOids::FLOAT4OID.oid(), row.petal_width.into_datum()), (PgBuiltInOids::INT4OID.oid(), row.target.into_datum()), ]), diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index da1940f60..8deebe042 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -21,8 +21,7 @@ use crate::bindings::*; use crate::orm::*; #[allow(clippy::type_complexity)] -static DEPLOYED_MODELS_BY_ID: Lazy>>> = 
Lazy::new(|| Mutex::new(HashMap::new())); +static DEPLOYED_MODELS_BY_ID: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); #[derive(Debug)] pub struct Model { @@ -197,10 +196,7 @@ impl Model { hyperparams: result.get(6).unwrap().unwrap(), status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), - search: result - .get(9) - .unwrap() - .map(|search| Search::from_str(search).unwrap()), + search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()), search_params: result.get(10).unwrap().unwrap(), search_args: result.get(11).unwrap().unwrap(), created_at: result.get(12).unwrap().unwrap(), @@ -251,11 +247,15 @@ impl Model { "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, $2, $3, $4) RETURNING id", vec![ (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()), - (PgBuiltInOids::TEXTOID.oid(), path.file_name().unwrap().to_str().into_datum()), + ( + PgBuiltInOids::TEXTOID.oid(), + path.file_name().unwrap().to_str().into_datum(), + ), (PgBuiltInOids::INT8OID.oid(), (i as i64).into_datum()), (PgBuiltInOids::BYTEAOID.oid(), chunk.into_datum()), ], - ).unwrap(); + ) + .unwrap(); } } @@ -360,10 +360,7 @@ impl Model { hyperparams: result.get(6).unwrap().unwrap(), status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(), metrics: result.get(8).unwrap(), - search: result - .get(9) - .unwrap() - .map(|search| Search::from_str(search).unwrap()), + search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()), search_params: result.get(10).unwrap().unwrap(), search_args: result.get(11).unwrap().unwrap(), created_at: result.get(12).unwrap().unwrap(), @@ -379,12 +376,7 @@ impl Model { Ok(()) })?; - model.ok_or_else(|| { - anyhow!( - "pgml.models WHERE id = {:?} could not be loaded. Does it exist?", - id - ) - }) + model.ok_or_else(|| anyhow!("pgml.models WHERE id = {:?} could not be loaded. 
Does it exist?", id)) } pub fn find_cached(id: i64) -> Result> { @@ -443,16 +435,12 @@ impl Model { Algorithm::random_forest => sklearn::random_forest_regression, Algorithm::xgboost => sklearn::xgboost_regression, Algorithm::xgboost_random_forest => sklearn::xgboost_random_forest_regression, - Algorithm::orthogonal_matching_pursuit => { - sklearn::orthogonal_matching_persuit_regression - } + Algorithm::orthogonal_matching_pursuit => sklearn::orthogonal_matching_persuit_regression, Algorithm::bayesian_ridge => sklearn::bayesian_ridge_regression, Algorithm::automatic_relevance_determination => { sklearn::automatic_relevance_determination_regression } - Algorithm::stochastic_gradient_descent => { - sklearn::stochastic_gradient_descent_regression - } + Algorithm::stochastic_gradient_descent => sklearn::stochastic_gradient_descent_regression, Algorithm::passive_aggressive => sklearn::passive_aggressive_regression, Algorithm::ransac => sklearn::ransac_regression, Algorithm::theil_sen => sklearn::theil_sen_regression, @@ -464,9 +452,7 @@ impl Model { Algorithm::ada_boost => sklearn::ada_boost_regression, Algorithm::bagging => sklearn::bagging_regression, Algorithm::extra_trees => sklearn::extra_trees_regression, - Algorithm::gradient_boosting_trees => { - sklearn::gradient_boosting_trees_regression - } + Algorithm::gradient_boosting_trees => sklearn::gradient_boosting_trees_regression, Algorithm::hist_gradient_boosting => sklearn::hist_gradient_boosting_regression, Algorithm::least_angle => sklearn::least_angle_regression, Algorithm::lasso_least_angle => sklearn::lasso_least_angle_regression, @@ -481,12 +467,8 @@ impl Model { Algorithm::ridge => sklearn::ridge_classification, Algorithm::random_forest => sklearn::random_forest_classification, Algorithm::xgboost => sklearn::xgboost_classification, - Algorithm::xgboost_random_forest => { - sklearn::xgboost_random_forest_classification - } - Algorithm::stochastic_gradient_descent => { - sklearn::stochastic_gradient_descent_classification - } + Algorithm::xgboost_random_forest => sklearn::xgboost_random_forest_classification, + Algorithm::stochastic_gradient_descent => sklearn::stochastic_gradient_descent_classification, Algorithm::perceptron => sklearn::perceptron_classification, Algorithm::passive_aggressive => sklearn::passive_aggressive_classification, Algorithm::gaussian_process => sklearn::gaussian_process, @@ -494,12 +476,8 @@ impl Model { Algorithm::ada_boost => sklearn::ada_boost_classification, Algorithm::bagging => sklearn::bagging_classification, Algorithm::extra_trees => sklearn::extra_trees_classification, - Algorithm::gradient_boosting_trees => { - sklearn::gradient_boosting_trees_classification - } - Algorithm::hist_gradient_boosting => { - sklearn::hist_gradient_boosting_classification - } + Algorithm::gradient_boosting_trees => sklearn::gradient_boosting_trees_classification, + Algorithm::hist_gradient_boosting => sklearn::hist_gradient_boosting_classification, Algorithm::linear_svm => sklearn::linear_svm_classification, Algorithm::lightgbm => sklearn::lightgbm_classification, Algorithm::catboost => sklearn::catboost_classification, @@ -531,17 +509,17 @@ impl Model { } for (key, values) in self.search_params.0.as_object().unwrap() { if all_hyperparam_names.contains(key) { - error!("`{key}` cannot be present in both hyperparams and search_params. Please choose one or the other."); + error!( + "`{key}` cannot be present in both hyperparams and search_params. Please choose one or the other." 
+ ); } all_hyperparam_names.push(key.to_string()); all_hyperparam_values.push(values.as_array().unwrap().to_vec()); } // The search space is all possible combinations - let all_hyperparam_values: Vec> = all_hyperparam_values - .into_iter() - .multi_cartesian_product() - .collect(); + let all_hyperparam_values: Vec> = + all_hyperparam_values.into_iter().multi_cartesian_product().collect(); let mut all_hyperparam_values = match self.search { Some(Search::random) => { // TODO support things like ranges to be random sampled @@ -587,17 +565,10 @@ impl Model { Task::regression => { #[cfg(all(feature = "python", any(test, feature = "pg_test")))] { - let sklearn_metrics = - crate::bindings::sklearn::regression_metrics(y_test, &y_hat).unwrap(); + let sklearn_metrics = crate::bindings::sklearn::regression_metrics(y_test, &y_hat).unwrap(); metrics.insert("sklearn_r2".to_string(), sklearn_metrics["r2"]); - metrics.insert( - "sklearn_mean_absolute_error".to_string(), - sklearn_metrics["mae"], - ); - metrics.insert( - "sklearn_mean_squared_error".to_string(), - sklearn_metrics["mse"], - ); + metrics.insert("sklearn_mean_absolute_error".to_string(), sklearn_metrics["mae"]); + metrics.insert("sklearn_mean_squared_error".to_string(), sklearn_metrics["mse"]); } let y_test = ArrayView1::from(&y_test); @@ -616,12 +587,9 @@ impl Model { Task::classification => { #[cfg(all(feature = "python", any(test, feature = "pg_test")))] { - let sklearn_metrics = crate::bindings::sklearn::classification_metrics( - y_test, - &y_hat, - dataset.num_distinct_labels, - ) - .unwrap(); + let sklearn_metrics = + crate::bindings::sklearn::classification_metrics(y_test, &y_hat, dataset.num_distinct_labels) + .unwrap(); if dataset.num_distinct_labels == 2 { metrics.insert("sklearn_roc_auc".to_string(), sklearn_metrics["roc_auc"]); @@ -629,10 +597,7 @@ impl Model { metrics.insert("sklearn_f1".to_string(), sklearn_metrics["f1"]); metrics.insert("sklearn_f1_micro".to_string(), sklearn_metrics["f1_micro"]); - metrics.insert( - "sklearn_precision".to_string(), - sklearn_metrics["precision"], - ); + metrics.insert("sklearn_precision".to_string(), sklearn_metrics["precision"]); metrics.insert("sklearn_recall".to_string(), sklearn_metrics["recall"]); metrics.insert("sklearn_accuracy".to_string(), sklearn_metrics["accuracy"]); metrics.insert("sklearn_mcc".to_string(), sklearn_metrics["mcc"]); @@ -646,10 +611,7 @@ impl Model { let y_hat = ArrayView1::from(&y_hat).mapv(Pr::new); let y_test: Vec = y_test.iter().map(|&i| i == 1.).collect(); - metrics.insert( - "roc_auc".to_string(), - y_hat.roc(&y_test).unwrap().area_under_curve(), - ); + metrics.insert("roc_auc".to_string(), y_hat.roc(&y_test).unwrap().area_under_curve()); metrics.insert("log_loss".to_string(), y_hat.log_loss(&y_test).unwrap()); } @@ -662,11 +624,8 @@ impl Model { let confusion_matrix = y_hat.confusion_matrix(y_test).unwrap(); // This has to be identical to Scikit. - let pgml_confusion_matrix = crate::metrics::ConfusionMatrix::new( - &y_test, - &y_hat, - dataset.num_distinct_labels, - ); + let pgml_confusion_matrix = + crate::metrics::ConfusionMatrix::new(&y_test, &y_hat, dataset.num_distinct_labels); // These are validated against Scikit and seem to be correct. 
metrics.insert( @@ -683,12 +642,9 @@ impl Model { Task::cluster => { #[cfg(feature = "python")] { - let sklearn_metrics = crate::bindings::sklearn::cluster_metrics( - dataset.num_features, - &dataset.x_test, - &y_hat, - ) - .unwrap(); + let sklearn_metrics = + crate::bindings::sklearn::cluster_metrics(dataset.num_features, &dataset.x_test, &y_hat) + .unwrap(); metrics.insert("silhouette".to_string(), sklearn_metrics["silhouette"]); } } @@ -703,10 +659,7 @@ impl Model { dataset: &Dataset, hyperparams: &Hyperparams, ) -> (Box, IndexMap) { - info!( - "Hyperparams: {}", - serde_json::to_string_pretty(hyperparams).unwrap() - ); + info!("Hyperparams: {}", serde_json::to_string_pretty(hyperparams).unwrap()); let fit = self.get_fit_function(); let now = Instant::now(); @@ -749,25 +702,11 @@ impl Model { } pub fn f1(&self) -> f32 { - self.metrics - .as_ref() - .unwrap() - .0 - .get("f1") - .unwrap() - .as_f64() - .unwrap() as f32 + self.metrics.as_ref().unwrap().0.get("f1").unwrap().as_f64().unwrap() as f32 } pub fn r2(&self) -> f32 { - self.metrics - .as_ref() - .unwrap() - .0 - .get("r2") - .unwrap() - .as_f64() - .unwrap() as f32 + self.metrics.as_ref().unwrap().0.get("r2").unwrap().as_f64().unwrap() as f32 } fn fit(&mut self, dataset: &Dataset) { @@ -955,9 +894,13 @@ impl Model { "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id", vec![ (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), - (PgBuiltInOids::BYTEAOID.oid(), self.bindings.as_ref().unwrap().to_bytes().into_datum()), + ( + PgBuiltInOids::BYTEAOID.oid(), + self.bindings.as_ref().unwrap().to_bytes().into_datum(), + ), ], - ).unwrap(); + ) + .unwrap(); } pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec { @@ -976,68 +919,47 @@ impl Model { pgrx_pg_sys::UNKNOWNOID => { error!("Type information missing for column: {:?}. If this is intended to be a TEXT or other categorical column, you will need to explicitly cast it, e.g. 
change `{:?}` to `CAST({:?} AS TEXT)`.", column.name, column.name, column.name); } - pgrx_pg_sys::TEXTOID - | pgrx_pg_sys::VARCHAROID - | pgrx_pg_sys::BPCHAROID => { + pgrx_pg_sys::TEXTOID | pgrx_pg_sys::VARCHAROID | pgrx_pg_sys::BPCHAROID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index); - element - .unwrap() - .unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) + element.unwrap().unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string()) } pgrx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index); element .unwrap() - .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { - k.to_string() - }) + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } pgrx_pg_sys::INT2OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); element .unwrap() - .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { - k.to_string() - }) + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } pgrx_pg_sys::INT4OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); element .unwrap() - .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { - k.to_string() - }) + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } pgrx_pg_sys::INT8OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); element .unwrap() - .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { - k.to_string() - }) + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } pgrx_pg_sys::FLOAT4OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); element .unwrap() - .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { - k.to_string() - }) + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } pgrx_pg_sys::FLOAT8OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); element .unwrap() - .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| { - k.to_string() - }) + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } _ => error!( "Unsupported type for categorical column: {:?}. 
oid: {:?}", @@ -1055,38 +977,27 @@ impl Model { pgrx_pg_sys::BOOLOID => { let element: Result, TryFromDatumError> = tuple.get_by_index(index); - features.push( - element.unwrap().map_or(f32::NAN, |v| v as u8 as f32), - ); + features.push(element.unwrap().map_or(f32::NAN, |v| v as u8 as f32)); } pgrx_pg_sys::INT2OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); - features - .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); + features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgrx_pg_sys::INT4OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); - features - .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); + features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgrx_pg_sys::INT8OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); - features - .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); + features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } pgrx_pg_sys::FLOAT4OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); features.push(element.unwrap().map_or(f32::NAN, |v| v)); } pgrx_pg_sys::FLOAT8OID => { - let element: Result, TryFromDatumError> = - tuple.get_by_index(index); - features - .push(element.unwrap().map_or(f32::NAN, |v| v as f32)); + let element: Result, TryFromDatumError> = tuple.get_by_index(index); + features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } // TODO handle NULL to NaN for arrays pgrx_pg_sys::BOOLARRAYOID => { @@ -1140,9 +1051,7 @@ impl Model { } } } - _ => error!( - "This preprocessing requires Postgres `record` types created with `row()`." - ), + _ => error!("This preprocessing requires Postgres `record` types created with `row()`."), } } features @@ -1166,11 +1075,11 @@ impl Model { pub fn predict_joint(&self, features: &[f32]) -> Result> { match self.project.task { - Task::regression => self.bindings.as_ref().unwrap().predict( - features, - self.num_features, - self.num_classes, - ), + Task::regression => self + .bindings + .as_ref() + .unwrap() + .predict(features, self.num_features, self.num_classes), Task::classification => { bail!("You can't predict joint probabilities for a classification model") } diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs index a30db3169..ea23ba80e 100644 --- a/pgml-extension/src/orm/project.rs +++ b/pgml-extension/src/orm/project.rs @@ -8,10 +8,8 @@ use pgrx::*; use crate::orm::*; -static PROJECT_ID_TO_DEPLOYED_MODEL_ID: PgLwLock> = - PgLwLock::new(); -static PROJECT_NAME_TO_PROJECT_ID: Lazy>> = - Lazy::new(|| Mutex::new(HashMap::new())); +static PROJECT_ID_TO_DEPLOYED_MODEL_ID: PgLwLock> = PgLwLock::new(); +static PROJECT_NAME_TO_PROJECT_ID: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); /// Initialize shared memory. 
/// # Note @@ -56,23 +54,12 @@ impl Project { ); let (project_id, model_id) = match result { Ok(o) => o, - Err(_) => error!( - "No deployed model exists for the project named: `{}`", - project_name - ), + Err(_) => error!("No deployed model exists for the project named: `{}`", project_name), }; - let project_id = project_id.unwrap_or_else(|| { - error!( - "No deployed model exists for the project named: `{}`", - project_name - ) - }); - let model_id = model_id.unwrap_or_else(|| { - error!( - "No deployed model exists for the project named: `{}`", - project_name - ) - }); + let project_id = project_id + .unwrap_or_else(|| error!("No deployed model exists for the project named: `{}`", project_name)); + let model_id = model_id + .unwrap_or_else(|| error!("No deployed model exists for the project named: `{}`", project_name)); projects.insert(project_name.to_string(), project_id); let mut projects = PROJECT_ID_TO_DEPLOYED_MODEL_ID.exclusive(); if projects.len() == 1024 { @@ -83,10 +70,7 @@ impl Project { project_id } }; - *PROJECT_ID_TO_DEPLOYED_MODEL_ID - .share() - .get(&project_id) - .unwrap() + *PROJECT_ID_TO_DEPLOYED_MODEL_ID.share().get(&project_id).unwrap() } pub fn deploy(&self, model_id: i64, strategy: Strategy) { @@ -111,12 +95,14 @@ impl Project { let mut project: Option = None; Spi::connect(|client| { - let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE id = $1 LIMIT 1;", - Some(1), - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), id.into_datum()), - ]) - ).unwrap().first(); + let result = client + .select( + "SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE id = $1 LIMIT 1;", + Some(1), + Some(vec![(PgBuiltInOids::INT8OID.oid(), id.into_datum())]), + ) + .unwrap() + .first(); if !result.is_empty() { project = Some(Project { id: result.get(1).unwrap().unwrap(), @@ -135,12 +121,14 @@ impl Project { let mut project = None; Spi::connect(|client| { - let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE name = $1 LIMIT 1;", - Some(1), - Some(vec![ - (PgBuiltInOids::TEXTOID.oid(), name.into_datum()), - ]) - ).unwrap().first(); + let result = client + .select( + "SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE name = $1 LIMIT 1;", + Some(1), + Some(vec![(PgBuiltInOids::TEXTOID.oid(), name.into_datum())]), + ) + .unwrap() + .first(); if !result.is_empty() { project = Some(Project { id: result.get(1).unwrap().unwrap(), diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 85f697508..6a5973148 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -163,13 +163,10 @@ impl Column { pub(crate) fn scale(&self, value: f32) -> f32 { match self.preprocessor.scale { Scale::standard => (value - self.statistics.mean) / self.statistics.std_dev, - Scale::min_max => { - (value - self.statistics.min) / (self.statistics.max - self.statistics.min) - } + Scale::min_max => (value - self.statistics.min) / (self.statistics.max - self.statistics.min), Scale::max_abs => value / self.statistics.max_abs, Scale::robust => { - (value - self.statistics.median) - / (self.statistics.ventiles[15] - self.statistics.ventiles[5]) + (value - self.statistics.median) / (self.statistics.ventiles[15] - self.statistics.ventiles[5]) } Scale::preserve => value, } @@ -456,10 +453,7 @@ impl Snapshot { LIMIT 1; ", Some(1), - Some(vec![( - PgBuiltInOids::INT8OID.oid(), - project_id.into_datum(), - )]), + 
Some(vec![(PgBuiltInOids::INT8OID.oid(), project_id.into_datum())]), ) .unwrap() .first(); @@ -467,8 +461,7 @@ impl Snapshot { let jsonb: JsonB = result.get(7).unwrap().unwrap(); let columns: Vec = serde_json::from_value(jsonb.0).unwrap(); let jsonb: JsonB = result.get(8).unwrap().unwrap(); - let analysis: Option> = - Some(serde_json::from_value(jsonb.0).unwrap()); + let analysis: Option> = Some(serde_json::from_value(jsonb.0).unwrap()); let mut s = Snapshot { id: result.get(1).unwrap().unwrap(), @@ -505,8 +498,7 @@ impl Snapshot { // Validate table exists. let (schema_name, table_name) = Self::fully_qualified_table(relation_name); - let preprocessors: HashMap = - serde_json::from_value(preprocess.0).expect("is valid"); + let preprocessors: HashMap = serde_json::from_value(preprocess.0).expect("is valid"); Spi::connect(|mut client| { let mut columns: Vec = Vec::new(); @@ -674,9 +666,7 @@ impl Snapshot { } pub(crate) fn first_label(&self) -> &Column { - self.labels() - .find(|l| l.name == self.y_column_name[0]) - .unwrap() + self.labels().find(|l| l.name == self.y_column_name[0]).unwrap() } pub(crate) fn num_classes(&self) -> usize { @@ -716,9 +706,12 @@ impl Snapshot { match schema_name { None => { - let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public'", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) - ]).unwrap().unwrap(); + let table_count = Spi::get_one_with_args::( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public'", + vec![(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())], + ) + .unwrap() + .unwrap(); let error = format!("Relation \"{}\" could not be found in the public schema. Please specify the table schema, e.g. 
pgml.{}", table_name, table_name); @@ -730,18 +723,19 @@ impl Snapshot { } Some(schema_name) => { - let exists = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()), - (PgBuiltInOids::TEXTOID.oid(), schema_name.clone().into_datum()), - ]).unwrap(); + let exists = Spi::get_one_with_args::( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2", + vec![ + (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()), + (PgBuiltInOids::TEXTOID.oid(), schema_name.clone().into_datum()), + ], + ) + .unwrap(); if exists == Some(1) { (schema_name, table_name) } else { - error!( - "Relation \"{}\".\"{}\" doesn't exist", - schema_name, table_name - ); + error!("Relation \"{}\".\"{}\" doesn't exist", schema_name, table_name); } } } @@ -818,12 +812,10 @@ impl Snapshot { }; match column.pg_type.as_str() { - "bpchar" | "text" | "varchar" => { - match row[column.position].value::().unwrap() { - Some(text) => vector.push(text), - None => error!("NULL training text is not handled"), - } - } + "bpchar" | "text" | "varchar" => match row[column.position].value::().unwrap() { + Some(text) => vector.push(text), + None => error!("NULL training text is not handled"), + }, _ => error!("only text type columns are supported"), } } @@ -906,24 +898,15 @@ impl Snapshot { } let mut analysis = IndexMap::new(); - analysis.insert( - "samples".to_string(), - numeric_encoded_dataset.num_rows as f32, - ); + analysis.insert("samples".to_string(), numeric_encoded_dataset.num_rows as f32); self.analysis = Some(analysis); // Record the analysis Spi::run_with_args( "UPDATE pgml.snapshots SET analysis = $1, columns = $2 WHERE id = $3", Some(vec![ - ( - PgBuiltInOids::JSONBOID.oid(), - JsonB(json!(self.analysis)).into_datum(), - ), - ( - PgBuiltInOids::JSONBOID.oid(), - JsonB(json!(self.columns)).into_datum(), - ), + (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.analysis)).into_datum()), + (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.columns)).into_datum()), (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()), ]), ) @@ -1001,14 +984,19 @@ impl Snapshot { // Categorical encoding types Some(categories) => { let key = match column.pg_type.as_str() { - "bool" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - "int2" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - "int4" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - "int8" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - "float4" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - "float8" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - "bpchar" | "text" | "varchar" => row[column.position].value::().unwrap().map(|v| v.to_string() ), - _ => error!("Unhandled type for categorical variable: {} {:?}", column.name, column.pg_type) + "bool" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "int2" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "int4" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "int8" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "float4" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "float8" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "bpchar" | "text" | "varchar" => { + row[column.position].value::().unwrap().map(|v| v.to_string()) + 
} + _ => error!( + "Unhandled type for categorical variable: {} {:?}", + column.name, column.pg_type + ), }; let key = key.unwrap_or_else(|| NULL_CATEGORY_KEY.to_string()); if i < num_train_rows { @@ -1018,16 +1006,18 @@ impl Snapshot { NULL_CATEGORY_KEY => 0_f32, // NULL values are always Category 0 _ => match &column.preprocessor.encode { Encode::target | Encode::native | Encode::one_hot { .. } => len as f32, - Encode::ordinal(values) => match values.iter().position(|v| v == key.as_str()) { - Some(i) => (i + 1) as f32, - None => error!("value is not present in ordinal: {:?}. Valid values: {:?}", key, values), + Encode::ordinal(values) => { + match values.iter().position(|v| v == key.as_str()) { + Some(i) => (i + 1) as f32, + None => error!( + "value is not present in ordinal: {:?}. Valid values: {:?}", + key, values + ), + } } - } + }, }; - Category { - value, - members: 0 - } + Category { value, members: 0 } }); category.members += 1; vector.push(category.value); @@ -1088,9 +1078,13 @@ impl Snapshot { vector.push(j as f32) } } - _ => error!("Unhandled type for quantitative array column: {} {:?}", column.name, column.pg_type) + _ => error!( + "Unhandled type for quantitative array column: {} {:?}", + column.name, column.pg_type + ), } - } else { // scalar + } else { + // scalar let float = match column.pg_type.as_str() { "bool" => row[column.position].value::().unwrap().map(|v| v as u8 as f32), "int2" => row[column.position].value::().unwrap().map(|v| v as f32), @@ -1098,7 +1092,10 @@ impl Snapshot { "int8" => row[column.position].value::().unwrap().map(|v| v as f32), "float4" => row[column.position].value::().unwrap(), "float8" => row[column.position].value::().unwrap().map(|v| v as f32), - _ => error!("Unhandled type for quantitative scalar column: {} {:?}", column.name, column.pg_type) + _ => error!( + "Unhandled type for quantitative scalar column: {} {:?}", + column.name, column.pg_type + ), }; match float { Some(f) => vector.push(f), @@ -1114,7 +1111,7 @@ impl Snapshot { let num_features = self.num_features(); let num_labels = self.num_labels(); - data = Some(Dataset{ + data = Some(Dataset { x_train, y_train, x_test, @@ -1129,7 +1126,8 @@ impl Snapshot { }); Ok::, i64>(Some(())) // this return type is nonsense - }).unwrap(); + }) + .unwrap(); let data = data.unwrap(); diff --git a/pgml-extension/src/vectors.rs b/pgml-extension/src/vectors.rs index ccaafa28a..b2114b7dd 100644 --- a/pgml-extension/src/vectors.rs +++ b/pgml-extension/src/vectors.rs @@ -115,18 +115,12 @@ fn divide_vector_d(vector: Array, dividend: Array) -> Vec { #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] fn norm_l0_s(vector: Array) -> f32 { - vector - .iter_deny_null() - .map(|a| if a == 0.0 { 0.0 } else { 1.0 }) - .sum() + vector.iter_deny_null().map(|a| if a == 0.0 { 0.0 } else { 1.0 }).sum() } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] fn norm_l0_d(vector: Array) -> f64 { - vector - .iter_deny_null() - .map(|a| if a == 0.0 { 0.0 } else { 1.0 }) - .sum() + vector.iter_deny_null().map(|a| if a == 0.0 { 0.0 } else { 1.0 }).sum() } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] @@ -334,11 +328,7 @@ impl Aggregate for SumS { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state<'a>( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state<'a>(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ 
-356,11 +346,7 @@ impl Aggregate for SumS { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -397,11 +383,7 @@ impl Aggregate for SumD { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -419,11 +401,7 @@ impl Aggregate for SumD { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -460,11 +438,7 @@ impl Aggregate for MaxAbsS { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -484,11 +458,7 @@ impl Aggregate for MaxAbsS { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -527,11 +497,7 @@ impl Aggregate for MaxAbsD { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -551,11 +517,7 @@ impl Aggregate for MaxAbsD { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -594,11 +556,7 @@ impl Aggregate for MaxS { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -618,11 +576,7 @@ impl Aggregate for MaxS { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -661,11 +615,7 @@ impl Aggregate for MaxD { type Finalize = Vec; 
#[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -685,11 +635,7 @@ impl Aggregate for MaxD { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -728,11 +674,7 @@ impl Aggregate for MinS { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -752,11 +694,7 @@ impl Aggregate for MinS { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -795,11 +733,7 @@ impl Aggregate for MinD { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -819,11 +753,7 @@ impl Aggregate for MinD { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -862,11 +792,7 @@ impl Aggregate for MinAbsS { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -886,11 +812,7 @@ impl Aggregate for MinAbsS { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -929,11 +851,7 @@ impl Aggregate for MinAbsD { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( - mut current: Self::State, - arg: Self::Args, - _fcinfo: pg_sys::FunctionCallInfo, - ) -> Self::State { + fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match arg { None => {} Some(arg) => match current { @@ -953,11 +871,7 @@ impl Aggregate for MinAbsD { } #[pgrx(immutable, parallel_safe)] - fn combine( - mut first: Self::State, - second: Self::State, - _fcinfo: 
pg_sys::FunctionCallInfo, - ) -> Self::State { + fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State { match (&mut first, &second) { (None, None) => None, (Some(_), None) => first, @@ -1043,65 +957,57 @@ mod tests { #[pg_test] fn test_add_vector_s() { - let result = Spi::get_one::>( - "SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", - ); + let result = + Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_add_vector_d() { - let result = Spi::get_one::>( - "SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", - ); + let result = + Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_subtract_vector_s() { - let result = Spi::get_one::>( - "SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", - ); + let result = + Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_subtract_vector_d() { - let result = Spi::get_one::>( - "SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", - ); + let result = + Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_multiply_vector_s() { - let result = Spi::get_one::>( - "SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", - ); + let result = + Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_multiply_vector_d() { - let result = Spi::get_one::>( - "SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", - ); + let result = + Spi::get_one::>("SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); assert_eq!(result, Ok(Some([1.0, 4.0, 9.0].to_vec()))); } #[pg_test] fn test_divide_vector_s() { - let result = Spi::get_one::>( - "SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", - ); + let result = + Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec()))); } #[pg_test] fn test_divide_vector_d() { - let result = Spi::get_one::>( - "SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", - ); + let result = + Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec()))); } @@ -1178,9 +1084,7 @@ mod tests { let result = Spi::get_one::>("SELECT pgml.normalize_l1(ARRAY[1,2,3]::float8[])"); assert_eq!( result, - Ok(Some( - [0.16666666666666666, 0.3333333333333333, 0.5].to_vec() - )) + Ok(Some([0.16666666666666666, 0.3333333333333333, 0.5].to_vec())) ); } @@ -1217,67 +1121,48 @@ mod tests { #[pg_test] fn test_normalize_max_d() { let result = Spi::get_one::>("SELECT pgml.normalize_max(ARRAY[1,2,3]::float8[])"); - assert_eq!( - result, - Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec())) - ); + assert_eq!(result, Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec()))); } #[pg_test] fn test_distance_l1_s() { - let 
result = Spi::get_one::( - "SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])", - ); + let result = Spi::get_one::("SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l1_d() { - let result = Spi::get_one::( - "SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])", - ); + let result = Spi::get_one::("SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l2_s() { - let result = Spi::get_one::( - "SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])", - ); + let result = Spi::get_one::("SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l2_d() { - let result = Spi::get_one::( - "SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])", - ); + let result = Spi::get_one::("SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_dot_product_s() { - let result = Spi::get_one::( - "SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])", - ); + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); assert_eq!(result, Ok(Some(14.0))); - let result = Spi::get_one::( - "SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])", - ); + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])"); assert_eq!(result, Ok(Some(20.0))); } #[pg_test] fn test_dot_product_d() { - let result = Spi::get_one::( - "SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])", - ); + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); assert_eq!(result, Ok(Some(14.0))); - let result = Spi::get_one::( - "SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])", - ); + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])"); assert_eq!(result, Ok(Some(20.0))); } @@ -1299,7 +1184,10 @@ mod tests { let want = 0.9925833; assert!((got - want).abs() < F32_TOLERANCE); - let got = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float4[], ARRAY[0,0,1,1,0,1,1]::float4[])").unwrap() + let got = Spi::get_one::( + "SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float4[], ARRAY[0,0,1,1,0,1,1]::float4[])", + ) + .unwrap() .unwrap(); let want = 0.4472136; assert!((got - want).abs() < F32_TOLERANCE); @@ -1323,7 +1211,11 @@ mod tests { let want = 0.9925833339709303; assert!((got - want).abs() < F64_TOLERANCE); - let got = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float8[], ARRAY[0,0,1,1,0,1,1]::float8[])").unwrap().unwrap(); + let got = Spi::get_one::( + "SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float8[], ARRAY[0,0,1,1,0,1,1]::float8[])", + ) + .unwrap() + .unwrap(); let want = 0.4472135954999579; assert!((got - want).abs() < F64_TOLERANCE); } From 936b1cda13adcc349458a9822d1fa1a16c7b150d Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 3 Jan 2024 12:43:59 -0800 Subject: [PATCH 4/4] newline --- pgml-extension/rustfmt.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/rustfmt.toml b/pgml-extension/rustfmt.toml index 3ccd3c986..94ac875fa 100644 --- a/pgml-extension/rustfmt.toml +++ 
b/pgml-extension/rustfmt.toml @@ -1 +1 @@ -max_width=120 \ No newline at end of file +max_width=120
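
Note on the formatting changes above: the line-joining hunks are consistent with mechanical rustfmt output under the wider line limit set in rustfmt.toml, and the final hunk only restores the trailing newline the file was missing. A minimal sketch of reproducing the same formatting locally, assuming a standard Rust toolchain with the rustfmt component installed and this repository's layout:

    # pgml-extension/rustfmt.toml
    max_width = 120

    # From the repository root: cargo fmt picks up rustfmt.toml
    # from the crate directory and rewrites the sources in place.
    cd pgml-extension
    cargo fmt

Running `cargo fmt -- --check` instead exits non-zero when any file would change, which is a cheap way to verify that a tree already matches this configuration.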