diff --git a/.github/workflows/javascript-sdk.yml b/.github/workflows/javascript-sdk.yml index 8e929976e..63d84e418 100644 --- a/.github/workflows/javascript-sdk.yml +++ b/.github/workflows/javascript-sdk.yml @@ -58,7 +58,7 @@ jobs: - neon-out-name: "aarch64-unknown-linux-gnu-index.node" os: "buildjet-4vcpu-ubuntu-2204-arm" runs-on: ubuntu-latest - container: ubuntu:16.04 + container: quay.io/pypa/manylinux2014_x86_64 defaults: run: working-directory: pgml-sdks/pgml/javascript @@ -66,9 +66,7 @@ jobs: - uses: actions/checkout@v3 - name: Install dependencies run: | - apt update - apt-get -y install curl - apt-get -y install build-essential + yum install -y perl-IPC-Cmd - uses: actions-rs/toolchain@v1 with: toolchain: stable diff --git a/pgml-dashboard/Cargo.lock b/pgml-dashboard/Cargo.lock index f633d6673..6d9483caf 100644 --- a/pgml-dashboard/Cargo.lock +++ b/pgml-dashboard/Cargo.lock @@ -212,15 +212,6 @@ dependencies = [ "syn 2.0.32", ] -[[package]] -name = "atoi" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" -dependencies = [ - "num-traits", -] - [[package]] name = "atoi" version = "2.0.0" @@ -757,7 +748,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot", "signal-hook", "signal-hook-mio", "winapi", @@ -989,26 +980,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "dirs" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - [[package]] name = "dotenv" version = "0.15.0" @@ -1345,17 +1316,6 @@ dependencies = [ "futures-util", ] -[[package]] -name = "futures-intrusive" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" -dependencies = [ - "futures-core", - "lock_api", - "parking_lot 0.11.2", -] - [[package]] name = "futures-intrusive" version = "0.5.0" @@ -1364,7 +1324,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.12.1", + "parking_lot", ] [[package]] @@ -2515,17 +2475,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -2533,21 +2482,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -2609,7 +2544,7 @@ 
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgml" -version = "0.10.1" +version = "1.0.0" dependencies = [ "anyhow", "async-trait", @@ -2624,7 +2559,7 @@ dependencies = [ "itertools", "lopdf", "md5", - "parking_lot 0.12.1", + "parking_lot", "regex", "reqwest", "rust_bridge", @@ -2632,7 +2567,7 @@ dependencies = [ "sea-query-binder", "serde", "serde_json", - "sqlx 0.6.3", + "sqlx", "tokio", "tracing", "tracing-subscriber", @@ -2669,7 +2604,7 @@ dependencies = [ "markdown", "num-traits", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "pgml", "pgml-components", "pgvector", @@ -2685,7 +2620,7 @@ dependencies = [ "sentry-log", "serde", "serde_json", - "sqlx 0.7.3", + "sqlx", "tantivy", "time", "tokio", @@ -2702,7 +2637,7 @@ checksum = "a1f4c0c07ceb64a0020f2f0e610cfe51122d2e72723499f0154877b7c76c8c31" dependencies = [ "bytes", "postgres", - "sqlx 0.7.3", + "sqlx", ] [[package]] @@ -3079,17 +3014,6 @@ dependencies = [ "bitflags 1.3.2", ] -[[package]] -name = "redox_users" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" -dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", -] - [[package]] name = "ref-cast" version = "1.0.18" @@ -3239,7 +3163,7 @@ dependencies = [ "memchr", "multer", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "rand", "ref-cast", @@ -3412,18 +3336,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "rustls" -version = "0.20.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" -dependencies = [ - "log", - "ring 0.16.20", - "sct", - "webpki", -] - [[package]] name = "rustls" version = "0.21.10" @@ -3569,14 +3481,15 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.29.1" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +checksum = "4166a1e072292d46dc91f31617c2a1cdaf55a8be4b5c9f4bf2ba248e3ac4999b" dependencies = [ "inherent", "sea-query-attr", "sea-query-derive", "serde_json", + "uuid", ] [[package]] @@ -3593,13 +3506,14 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", - "sqlx 0.6.3", + "sqlx", + "uuid", ] [[package]] @@ -4031,84 +3945,19 @@ dependencies = [ "unicode_categories", ] -[[package]] -name = "sqlx" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" -dependencies = [ - "sqlx-core 0.6.3", - "sqlx-macros 0.6.3", -] - [[package]] name = "sqlx" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" dependencies = [ - "sqlx-core 0.7.3", - "sqlx-macros 0.7.3", + "sqlx-core", + "sqlx-macros", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", ] -[[package]] -name = "sqlx-core" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" -dependencies = [ - "ahash 0.7.6", - "atoi 1.0.0", - "base64 0.13.1", - "bitflags 1.3.2", - "byteorder", - "bytes", - "crc", - "crossbeam-queue", - "dirs", - "dotenvy", - "either", - "event-listener", - "futures-channel", - "futures-core", - "futures-intrusive 0.4.2", - "futures-util", - "hashlink", - "hex", - "hkdf", - "hmac", - "indexmap 1.9.3", - "itoa", - "libc", - "log", - "md-5", - "memchr", - "once_cell", - "paste", - "percent-encoding", - "rand", - "rustls 0.20.8", - "rustls-pemfile", - "serde", - "serde_json", - "sha1", - "sha2", - "smallvec", - "sqlformat", - "sqlx-rt", - "stringprep", - "thiserror", - "time", - "tokio-stream", - "url", - "uuid", - "webpki-roots 0.22.6", - "whoami", -] - [[package]] name = "sqlx-core" version = "0.7.3" @@ -4116,7 +3965,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" dependencies = [ "ahash 0.8.7", - "atoi 2.0.0", + "atoi", "bigdecimal", "byteorder", "bytes", @@ -4127,7 +3976,7 @@ dependencies = [ "event-listener", "futures-channel", "futures-core", - "futures-intrusive 0.5.0", + "futures-intrusive", "futures-io", "futures-util", "hashlink", @@ -4138,7 +3987,7 @@ dependencies = [ "once_cell", "paste", "percent-encoding", - "rustls 0.21.10", + "rustls", "rustls-pemfile", "serde", "serde_json", @@ -4152,27 +4001,7 @@ dependencies = [ "tracing", "url", "uuid", - "webpki-roots 0.25.4", -] - -[[package]] -name = "sqlx-macros" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" -dependencies = [ - "dotenvy", - "either", - "heck", - "once_cell", - "proc-macro2", - "quote", - "serde_json", - "sha2", - "sqlx-core 0.6.3", - "sqlx-rt", - "syn 1.0.109", - "url", + "webpki-roots", ] [[package]] @@ -4183,7 +4012,7 @@ checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" dependencies = [ "proc-macro2", "quote", - "sqlx-core 0.7.3", + "sqlx-core", "sqlx-macros-core", "syn 1.0.109", ] @@ -4205,7 +4034,7 @@ dependencies = [ "serde", "serde_json", "sha2", - "sqlx-core 0.7.3", + "sqlx-core", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -4221,7 +4050,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ - "atoi 2.0.0", + "atoi", "base64 0.21.4", "bigdecimal", "bitflags 2.3.3", @@ -4251,7 +4080,7 @@ dependencies = [ "sha1", "sha2", "smallvec", - "sqlx-core 0.7.3", + "sqlx-core", "stringprep", "thiserror", "time", @@ -4266,7 +4095,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ - "atoi 2.0.0", + "atoi", "base64 0.21.4", "bigdecimal", "bitflags 2.3.3", @@ -4294,7 +4123,7 @@ dependencies = [ "sha1", "sha2", "smallvec", - "sqlx-core 0.7.3", + "sqlx-core", "stringprep", "thiserror", "time", @@ -4303,35 +4132,24 @@ dependencies = [ "whoami", ] -[[package]] -name = "sqlx-rt" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" -dependencies = [ - "once_cell", - "tokio", - "tokio-rustls", -] - [[package]] name = "sqlx-sqlite" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" dependencies = [ - "atoi 2.0.0", + "atoi", "flume", "futures-channel", "futures-core", "futures-executor", - "futures-intrusive 0.5.0", + "futures-intrusive", "futures-util", "libsqlite3-sys", "log", "percent-encoding", "serde", - "sqlx-core 0.7.3", + "sqlx-core", "time", "tracing", "url", @@ -4371,7 +4189,7 @@ checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" dependencies = [ "new_debug_unreachable", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "phf_shared 0.10.0", "precomputed-hash", "serde", @@ -4714,7 +4532,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2 0.4.9", @@ -4767,7 +4585,7 @@ dependencies = [ "futures-channel", "futures-util", "log", - "parking_lot 0.12.1", + "parking_lot", "percent-encoding", "phf 0.11.2", "pin-project-lite", @@ -4778,17 +4596,6 @@ dependencies = [ "tokio-util", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls 0.20.8", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -5311,25 +5118,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" -dependencies = [ - "ring 0.16.20", - "untrusted 0.7.1", -] - -[[package]] -name = "webpki-roots" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] - [[package]] name = "webpki-roots" version = "0.25.4" @@ -5347,10 +5135,6 @@ name = "whoami" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" -dependencies = [ - "wasm-bindgen", - "web-sys", -] [[package]] name = "winapi" diff --git a/pgml-dashboard/src/api/chatbot.rs b/pgml-dashboard/src/api/chatbot.rs index d5f439902..288b1df43 100644 --- a/pgml-dashboard/src/api/chatbot.rs +++ b/pgml-dashboard/src/api/chatbot.rs @@ -169,7 +169,6 @@ enum KnowledgeBase { } impl KnowledgeBase { - // The topic and knowledge base are the same for now but may be different later fn topic(&self) -> &'static str { match self { Self::PostgresML => "PostgresML", @@ -181,10 +180,10 @@ impl KnowledgeBase { fn collection(&self) -> &'static str { match self { - Self::PostgresML => "PostgresML", - Self::PyTorch => "PyTorch", - Self::Rust => "Rust", - Self::PostgreSQL => "PostgreSQL", + Self::PostgresML => "PostgresML_0", + Self::PyTorch => "PyTorch_0", + Self::Rust => "Rust_0", + Self::PostgreSQL => "PostgreSQL_0", } } } @@ -396,31 +395,29 @@ pub async fn chatbot_get_history(user: User) -> Json { async fn do_chatbot_get_history(user: &User, limit: usize) -> anyhow::Result> { let history_collection = Collection::new( - "ChatHistory", + "ChatHistory_0", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); + )?; let mut messages = history_collection .get_documents(Some( json!({ "limit": limit, "order_by": {"timestamp": "desc"}, "filter": { - "metadata": { - "$and" : [ - { - "$or": - [ - {"role": {"$eq": ChatRole::Bot}}, - 
{"role": {"$eq": ChatRole::User}} - ] - }, - { - "user_id": { - "$eq": user.chatbot_session_id - } + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id } - ] - } + } + ] } }) @@ -521,64 +518,64 @@ async fn process_message( knowledge_base, ); - let pipeline = Pipeline::new("v1", None, None, None); + let mut pipeline = Pipeline::new("v1", None)?; let collection = knowledge_base.collection(); - let collection = Collection::new( + let mut collection = Collection::new( collection, Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); + )?; let context = collection - .query() - .vector_recall( - &data.question, - &pipeline, - Some( - json!({ - "instruction": "Represent the Wikipedia question for retrieving supporting documents: " - }) - .into(), - ), + .vector_search( + serde_json::json!({ + "query": { + "fields": { + "text": { + "query": &data.question, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + } + }, + } + }}) + .into(), + &mut pipeline, ) - .limit(5) - .fetch_all() .await? .into_iter() - .map(|(_, context, metadata)| format!("\n\n#### Document {}: \n{}\n\n", metadata["id"], context)) + .map(|v| format!("\n\n#### Document {}: \n{}\n\n", v["document"]["id"], v["chunk"])) .collect::>() - .join("\n"); + .join(""); let history_collection = Collection::new( - "ChatHistory", + "ChatHistory_0", Some(std::env::var("CHATBOT_DATABASE_URL").expect("CHATBOT_DATABASE_URL not set")), - ); + )?; let mut messages = history_collection .get_documents(Some( json!({ "limit": 5, "order_by": {"timestamp": "desc"}, "filter": { - "metadata": { - "$and" : [ - { - "$or": - [ - {"role": {"$eq": ChatRole::Bot}}, - {"role": {"$eq": ChatRole::User}} - ] - }, - { - "user_id": { - "$eq": user.chatbot_session_id - } - }, - { - "knowledge_base": { - "$eq": knowledge_base - } - }, - // This is where we would match on the model if we wanted to - ] - } + "$and" : [ + { + "$or": + [ + {"role": {"$eq": ChatRole::Bot}}, + {"role": {"$eq": ChatRole::User}} + ] + }, + { + "user_id": { + "$eq": user.chatbot_session_id + } + }, + { + "knowledge_base": { + "$eq": knowledge_base + } + }, + // This is where we would match on the model if we wanted to + ] } }) diff --git a/pgml-dashboard/src/api/cms.rs b/pgml-dashboard/src/api/cms.rs index 2048b24c8..ee1060d02 100644 --- a/pgml-dashboard/src/api/cms.rs +++ b/pgml-dashboard/src/api/cms.rs @@ -559,9 +559,8 @@ impl Collection { } #[get("/search?", rank = 20)] -async fn search(query: &str, index: &State) -> ResponseOk { - let results = index.search(query).unwrap(); - +async fn search(query: &str, site_search: &State) -> ResponseOk { + let results = site_search.search(query, None).await.expect("Error performing search"); ResponseOk( Template(Search { query: query.to_string(), @@ -711,9 +710,9 @@ pub fn routes() -> Vec { #[cfg(test)] mod test { use super::*; - use crate::utils::markdown::{options, MarkdownHeadings, SyntaxHighlighter}; + use crate::utils::markdown::options; use regex::Regex; - use rocket::http::{ContentType, Cookie, Status}; + use rocket::http::Status; use rocket::local::asynchronous::Client; use rocket::{Build, Rocket}; @@ -779,7 +778,7 @@ This is the end of the markdown async fn rocket() -> Rocket { dotenv::dotenv().ok(); rocket::build() - .manage(crate::utils::markdown::SearchIndex::open().unwrap()) + // .manage(crate::utils::markdown::SearchIndex::open().unwrap()) 
.mount("/", crate::api::cms::routes()) } diff --git a/pgml-dashboard/src/main.rs b/pgml-dashboard/src/main.rs index f09b21d8b..ce38c5b8d 100644 --- a/pgml-dashboard/src/main.rs +++ b/pgml-dashboard/src/main.rs @@ -92,14 +92,20 @@ async fn main() { // it's important to hang on to sentry so it isn't dropped and stops reporting let _sentry = configure_reporting().await; - markdown::SearchIndex::build().await.unwrap(); + let site_search = markdown::SiteSearch::new() + .await + .expect("Error initializing site search"); + let mut site_search_copy = site_search.clone(); + tokio::spawn(async move { + site_search_copy.build().await.expect("Error building site search"); + }); pgml_dashboard::migrate(guards::Cluster::default(None).pool()) .await .unwrap(); let _ = rocket::build() - .manage(markdown::SearchIndex::open().unwrap()) + .manage(site_search) .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(config::static_dir())) .mount("/dashboard", pgml_dashboard::routes()) @@ -131,8 +137,13 @@ mod test { pgml_dashboard::migrate(Cluster::default(None).pool()).await.unwrap(); + let mut site_search = markdown::SiteSearch::new() + .await + .expect("Error initializing site search"); + site_search.build().await.expect("Error building site search"); + rocket::build() - .manage(markdown::SearchIndex::open().unwrap()) + .manage(site_search) .mount("/", rocket::routes![index, error]) .mount("/dashboard/static", FileServer::from(config::static_dir())) .mount("/dashboard", pgml_dashboard::routes()) diff --git a/pgml-dashboard/src/utils/markdown.rs b/pgml-dashboard/src/utils/markdown.rs index dcd878e3a..55c42b9b1 100644 --- a/pgml-dashboard/src/utils/markdown.rs +++ b/pgml-dashboard/src/utils/markdown.rs @@ -1,8 +1,9 @@ +use crate::api::cms::{DocType, Document}; use crate::{templates::docs::TocLink, utils::config}; - +use anyhow::Context; use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use std::path::{Path, PathBuf}; +use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; use anyhow::Result; @@ -10,21 +11,15 @@ use comrak::{ adapters::{HeadingAdapter, HeadingMeta, SyntaxHighlighterAdapter}, arena_tree::Node, nodes::{Ast, AstNode, NodeValue}, - parse_document, Arena, ComrakExtensionOptions, ComrakOptions, ComrakRenderOptions, + Arena, ComrakExtensionOptions, ComrakOptions, ComrakRenderOptions, }; use convert_case; use itertools::Itertools; use regex::Regex; -use tantivy::collector::TopDocs; -use tantivy::query::{QueryParser, RegexQuery}; -use tantivy::schema::*; -use tantivy::tokenizer::{LowerCaser, NgramTokenizer, TextAnalyzer}; -use tantivy::{Index, IndexReader, SnippetGenerator}; -use url::Url; - -use std::sync::Mutex; - +use serde::Deserialize; use std::fmt; +use std::sync::Mutex; +use url::Url; pub struct MarkdownHeadings { header_map: Arc>>, @@ -1222,31 +1217,72 @@ pub async fn get_document(path: &PathBuf) -> anyhow::Result { Ok(tokio::fs::read_to_string(path).await?) } +#[derive(Deserialize)] +struct SearchResultWithoutSnippet { + title: String, + contents: String, + path: String, +} + pub struct SearchResult { pub title: String, - pub body: String, pub path: String, pub snippet: String, } -pub struct SearchIndex { - // The index. - pub index: Arc, - - // Index schema (fields). - pub schema: Arc, - - // The index reader, supports concurrent access. 
-    pub reader: Arc<IndexReader>,
+#[derive(Clone)]
+pub struct SiteSearch {
+    collection: pgml::Collection,
+    pipeline: pgml::Pipeline,
 }
 
-impl SearchIndex {
-    pub fn path() -> PathBuf {
-        Path::new(&config::search_index_dir()).to_owned()
+impl SiteSearch {
+    pub async fn new() -> anyhow::Result<Self> {
+        let collection = pgml::Collection::new(
+            "hypercloud-site-search-c-2",
+            Some(
+                std::env::var("SITE_SEARCH_DATABASE_URL")
+                    .context("Please set the `SITE_SEARCH_DATABASE_URL` environment variable")?,
+            ),
+        )?;
+        let pipeline = pgml::Pipeline::new(
+            "hypercloud-site-search-p-0",
+            Some(
+                serde_json::json!({
+                    "title": {
+                        "full_text_search": {
+                            "configuration": "english"
+                        },
+                        "semantic_search": {
+                            "model": "hkunlp/instructor-base",
+                            "parameters": {
+                                "instruction": "Represent the Wikipedia document for retrieval: "
+                            },
+                        }
+                    },
+                    "contents": {
+                        "splitter": {
+                            "model": "recursive_character"
+                        },
+                        "full_text_search": {
+                            "configuration": "english"
+                        },
+                        "semantic_search": {
+                            "model": "hkunlp/instructor-base",
+                            "parameters": {
+                                "instruction": "Represent the Wikipedia document for retrieval: "
+                            },
+                        }
+                    }
+                })
+                .into(),
+            ),
+        )?;
+        Ok(Self { collection, pipeline })
     }
 
     pub fn documents() -> Vec<PathBuf> {
-        // TODO imrpove this .display().to_string()
+        // TODO improve this .display().to_string()
         let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string()).expect("glob failed");
         let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string()).expect("glob failed");
         guides
@@ -1255,224 +1291,106 @@
             .collect()
     }
 
-    pub fn schema() -> Schema {
-        // TODO: Make trigram title index
-        // and full text body index, and use trigram only if body gets nothing.
-        let mut schema_builder = Schema::builder();
-        let title_field_indexing = TextFieldIndexing::default()
-            .set_tokenizer("ngram3")
-            .set_index_option(IndexRecordOption::WithFreqsAndPositions);
-        let title_options = TextOptions::default()
-            .set_indexing_options(title_field_indexing)
-            .set_stored();
-
-        schema_builder.add_text_field("title", title_options.clone());
-        schema_builder.add_text_field("title_regex", TEXT | STORED);
-        schema_builder.add_text_field("body", TEXT | STORED);
-        schema_builder.add_text_field("path", STORED);
-
-        schema_builder.build()
-    }
-
-    pub async fn build() -> tantivy::Result<()> {
-        // Remove existing index.
-        let _ = std::fs::remove_dir_all(Self::path());
-        std::fs::create_dir(Self::path()).unwrap();
-
-        let index = tokio::task::spawn_blocking(move || -> tantivy::Result<Index> {
-            Index::create_in_dir(Self::path(), Self::schema())
-        })
-        .await
-        .unwrap()?;
-
-        let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser);
-
-        index.tokenizers().register("ngram3", ngram);
-
-        let schema = Self::schema();
-        let mut index_writer = index.writer(50_000_000)?;
-
-        for path in Self::documents().into_iter() {
-            let text = get_document(&path).await.unwrap();
-
-            let arena = Arena::new();
-            let root = parse_document(&arena, &text, &options());
-            let title_text = get_title(root).unwrap();
-            let body_text = get_text(root).unwrap().into_iter().join(" ");
-
-            let title_field = schema.get_field("title").unwrap();
-            let body_field = schema.get_field("body").unwrap();
-            let path_field = schema.get_field("path").unwrap();
-            let title_regex_field = schema.get_field("title_regex").unwrap();
-
-            info!("found path: {path}", path = path.display());
-            let path = path
-                .to_str()
-                .unwrap()
-                .to_string()
-                .split("content")
-                .last()
-                .unwrap()
-                .to_string()
-                .replace("README", "")
-                .replace(&config::cms_dir().display().to_string(), "");
-            let mut doc = Document::default();
-            doc.add_text(title_field, &title_text);
-            doc.add_text(body_field, &body_text);
-            doc.add_text(path_field, &path);
-            doc.add_text(title_regex_field, &title_text);
-
-            index_writer.add_document(doc)?;
-        }
-
-        tokio::task::spawn_blocking(move || -> tantivy::Result<u64> { index_writer.commit() })
-            .await
-            .unwrap()?;
-
-        Ok(())
-    }
-
-    pub fn open() -> tantivy::Result<SearchIndex> {
-        let path = Self::path();
-
-        if !path.exists() {
-            std::fs::create_dir(&path).expect("failed to create search_index directory, is the filesystem writable?");
-        }
-
-        let index = match tantivy::Index::open_in_dir(&path) {
-            Ok(index) => index,
-            Err(err) => {
-                warn!(
-                    "Failed to open Tantivy index in '{}', creating an empty one, error: {}",
-                    path.display(),
-                    err
-                );
-                Index::create_in_dir(&path, Self::schema())?
-            }
-        };
-
-        let reader = index.reader_builder().try_into()?;
-
-        let ngram = TextAnalyzer::from(NgramTokenizer::new(3, 3, false)).filter(LowerCaser);
-
-        index.tokenizers().register("ngram3", ngram);
-
-        Ok(SearchIndex {
-            index: Arc::new(index),
-            schema: Arc::new(Self::schema()),
-            reader: Arc::new(reader),
-        })
-    }
-
-    pub fn search(&self, query_string: &str) -> tantivy::Result<Vec<SearchResult>> {
-        let mut results = Vec::new();
-        let searcher = self.reader.searcher();
-        let title_field = self.schema.get_field("title").unwrap();
-        let body_field = self.schema.get_field("body").unwrap();
-        let path_field = self.schema.get_field("path").unwrap();
-        let title_regex_field = self.schema.get_field("title_regex").unwrap();
-
-        // Search using:
-        //
-        // 1. Full text search on the body
-        // 2. Trigrams on the title
-        let query_parser = QueryParser::for_index(&self.index, vec![title_field, body_field]);
-        let query = match query_parser.parse_query(query_string) {
-            Ok(query) => query,
-            Err(err) => {
-                warn!("Query parse error: {}", err);
-                return Ok(Vec::new());
-            }
-        };
-
-        let mut top_docs = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
-
-        // If that's not enough, search using prefix search on the title.
-        if top_docs.len() < 10 {
-            let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), title_regex_field) {
-                Ok(query) => query,
-                Err(err) => {
-                    warn!("Query regex error: {}", err);
-                    return Ok(Vec::new());
+    pub async fn search(&self, query: &str, doc_type: Option<DocType>) -> anyhow::Result<Vec<SearchResult>> {
+        let mut search = serde_json::json!({
+            "query": {
+                // "full_text_search": {
+                //     "title": {
+                //         "query": query,
+                //         "boost": 4.0
+                //     },
+                //     "contents": {
+                //         "query": query
+                //     }
+                // },
+                "semantic_search": {
+                    "title": {
+                        "query": query,
+                        "parameters": {
+                            "instruction": "Represent the Wikipedia question for retrieving supporting documents: "
+                        },
+                        "boost": 10.0
+                    },
+                    "contents": {
+                        "query": query,
+                        "parameters": {
+                            "instruction": "Represent the Wikipedia question for retrieving supporting documents: "
+                        },
+                        "boost": 1.0
+                    }
                 }
-            };
-
-            let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
-            top_docs.extend(more_results);
-        }
-
-        // Oh jeez ok
-        if top_docs.len() < 10 {
-            let query = match RegexQuery::from_pattern(&format!("{}.*", query_string), body_field) {
-                Ok(query) => query,
-                Err(err) => {
-                    warn!("Query regex error: {}", err);
-                    return Ok(Vec::new());
+            },
+            "limit": 10
+        });
+        if let Some(doc_type) = doc_type {
+            search["query"]["filter"] = serde_json::json!({
+                "doc_type": {
+                    "$eq": doc_type
                 }
-            };
-
-            let more_results = searcher.search(&query, &TopDocs::with_limit(10)).unwrap();
-            top_docs.extend(more_results);
-        }
-
-        // Generate snippets for the FTS query.
-        let snippet_generator = SnippetGenerator::create(&searcher, &*query, body_field)?;
-
-        let mut dedup = HashSet::new();
-
-        for (_score, doc_address) in top_docs {
-            let retrieved_doc = searcher.doc(doc_address)?;
-            let snippet = snippet_generator.snippet_from_doc(&retrieved_doc);
-            let path = retrieved_doc
-                .get_first(path_field)
-                .unwrap()
-                .as_text()
-                .unwrap()
-                .to_string()
-                .replace(".md", "")
-                .replace(&config::static_dir().display().to_string(), "");
-
-            // Dedup results from prefix search and full text search.
-            let new = dedup.insert(path.clone());
-
-            if !new {
-                continue;
-            }
-
-            let title = retrieved_doc
-                .get_first(title_field)
-                .unwrap()
-                .as_text()
-                .unwrap()
-                .to_string();
-            let body = retrieved_doc
-                .get_first(body_field)
-                .unwrap()
-                .as_text()
-                .unwrap()
-                .to_string();
-
-            let snippet = if snippet.is_empty() {
-                body.split(' ').take(20).collect::<Vec<&str>>().join(" ") + " ..."
-            } else {
-                "... ".to_string() + &snippet.to_html() + " ..."
-            };
-
-            results.push(SearchResult {
-                title,
-                body,
-                path,
-                snippet,
             });
         }
+        let results = self.collection.search_local(search.into(), &self.pipeline).await?;
+
+        results["results"]
+            .as_array()
+            .context("Error getting results from search")?
+            .into_iter()
+            .map(|r| {
+                let SearchResultWithoutSnippet { title, contents, path } =
+                    serde_json::from_value(r["document"].clone())?;
+                let path = path
+                    .replace(".md", "")
+                    .replace(&config::static_dir().display().to_string(), "");
+                Ok(SearchResult {
+                    title,
+                    path,
+                    snippet: contents.split(' ').take(20).collect::<Vec<&str>>().join(" ") + " ...",
+                })
+            })
+            .collect()
+    }
-        Ok(results)
+    pub async fn build(&mut self) -> anyhow::Result<()> {
+        self.collection.add_pipeline(&mut self.pipeline).await?;
+        let documents: Vec<Document> = futures::future::try_join_all(
+            Self::get_document_paths()?
+                .into_iter()
+                .map(|path| async move { Document::from_path(&path).await }),
         )
         .await?;
+        let documents: Vec<pgml::types::Json> = documents
+            .into_iter()
+            .map(|d| {
+                let mut document_json = serde_json::to_value(d).unwrap();
+                document_json["id"] = document_json["path"].clone();
+                document_json["path"] = serde_json::json!(document_json["path"]
+                    .as_str()
+                    .unwrap()
+                    .split("content")
+                    .last()
+                    .unwrap()
+                    .to_string()
+                    .replace("README", "")
+                    .replace(&config::cms_dir().display().to_string(), ""));
+                document_json.into()
+            })
+            .collect();
+        self.collection.upsert_documents(documents, None).await
+    }
+
+    fn get_document_paths() -> anyhow::Result<Vec<PathBuf>> {
+        // TODO improve this .display().to_string()
+        let guides = glob::glob(&config::cms_dir().join("docs/**/*.md").display().to_string())?;
+        let blogs = glob::glob(&config::cms_dir().join("blog/**/*.md").display().to_string())?;
+        Ok(guides
+            .chain(blogs)
+            .map(|path| path.expect("glob path failed"))
+            .collect())
     }
 }
 
 #[cfg(test)]
 mod test {
-    use super::*;
     use crate::utils::markdown::parser;
 
     #[test]
diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock
index 131380b9d..e651e5969 100644
--- a/pgml-sdks/pgml/Cargo.lock
+++ b/pgml-sdks/pgml/Cargo.lock
@@ -3,47 +3,47 @@ version = 3
 
 [[package]]
-name = "adler"
-version = "1.0.2"
+name = "addr2line"
+version = "0.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
+checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
+dependencies = [
+ "gimli",
+]
 
 [[package]]
-name = "ahash"
-version = "0.7.6"
+name = "adler"
+version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
-dependencies = [
- "getrandom",
- "once_cell",
- "version_check",
-]
+checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
 
 [[package]]
 name = "ahash"
-version = "0.8.3"
+version = "0.8.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
+checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01"
 dependencies = [
  "cfg-if",
+ "getrandom",
  "once_cell",
  "version_check",
+ "zerocopy",
 ]
 
 [[package]]
 name = "aho-corasick"
-version = "1.0.2"
+version = "1.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41"
+checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0"
 dependencies = [
  "memchr",
 ]
 
 [[package]]
 name = "allocator-api2"
-version = "0.2.14"
+version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c4f263788a35611fba42eb41ff811c5d0360c58b97402570312a350736e2542e"
+checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
 
 [[package]]
 name = "android-tzdata"
@@ -62,9 +62,9 @@ dependencies = [
 
 [[package]]
 name = "anstream"
-version = "0.6.4"
+version = "0.6.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44"
+checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5"
 dependencies = [
  "anstyle",
  "anstyle-parse",
@@ -76,64 +76,74 @@ dependencies = [
 
 [[package]]
 name = "anstyle"
-version = "1.0.4"
+version = "1.0.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anstyle-parse" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.0.0" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.1" +version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" dependencies = [ "anstyle", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "anyhow" -version = "1.0.71" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" [[package]] name = "async-trait" -version = "0.1.71" +version = "0.1.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" +checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "atoi" -version = "1.0.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c57d12312ff59c811c0643f4d80830505833c9ffaebd193d819392b265be8e" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ "num-traits", ] +[[package]] +name = "atomic-write-file" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436" +dependencies = [ + "nix", + "rand", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -141,16 +151,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] -name = "base64" -version = "0.13.1" +name = "backtrace" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] [[package]] name = "base64" -version = "0.21.2" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64ct" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = 
"8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" @@ -160,9 +185,12 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +dependencies = [ + "serde", +] [[package]] name = "block-buffer" @@ -175,27 +203,30 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] [[package]] name = "cfg-if" @@ -205,24 +236,23 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.26" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", - "time 0.1.45", "wasm-bindgen", - "winapi", + "windows-targets 0.52.0", ] [[package]] name = "clap" -version = "4.4.10" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" dependencies = [ "clap_builder", "clap_derive", @@ -230,9 +260,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.9" +version = "4.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" dependencies = [ "anstream", "anstyle", @@ -249,7 +279,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -266,33 +296,38 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "colored" -version = "2.0.4" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6" +checksum = 
"cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" dependencies = [ - "is-terminal", "lazy_static", "windows-sys 0.48.0", ] [[package]] name = "console" -version = "0.15.7" +version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" dependencies = [ "encode_unicode", "lazy_static", "libc", "unicode-width", - "windows-sys 0.45.0", + "windows-sys 0.52.0", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "core-foundation" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -300,15 +335,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" dependencies = [ "libc", ] @@ -324,9 +359,9 @@ dependencies = [ [[package]] name = "crc-catalog" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "crc32fast" @@ -339,46 +374,37 @@ dependencies = [ [[package]] name = "crossbeam-deque" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.15" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "autocfg", - "cfg-if", "crossbeam-utils", - "memoffset 0.9.0", - "scopeguard", ] [[package]] name = "crossbeam-queue" -version = "0.3.8" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" dependencies = [ - "cfg-if", "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" -dependencies = [ - "cfg-if", -] +checksum = 
"248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] name = "crossterm" @@ -390,7 +416,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot", "signal-hook", "signal-hook-mio", "winapi", @@ -417,12 +443,12 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.0" +version = "3.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a011bbe2c35ce9c1f143b7af6f94f29a167beb4cd1d29e6740ce836f723120e" +checksum = "b467862cc8610ca6fc9a1532d7777cee0804e678ab45410897b9396495994a0b" dependencies = [ "nix", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -461,34 +487,35 @@ dependencies = [ ] [[package]] -name = "digest" -version = "0.10.7" +name = "der" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" dependencies = [ - "block-buffer", - "crypto-common", - "subtle", + "const-oid", + "pem-rfc7468", + "zeroize", ] [[package]] -name = "dirs" -version = "4.0.0" +name = "deranged" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ - "dirs-sys", + "powerfmt", ] [[package]] -name = "dirs-sys" -version = "0.3.7" +name = "digest" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "libc", - "redox_users", - "winapi", + "block-buffer", + "const-oid", + "crypto-common", + "subtle", ] [[package]] @@ -505,9 +532,12 @@ checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -517,32 +547,38 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.32" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" -version = "0.3.1" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "errno-dragonfly", "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "etcetera" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" dependencies = [ - "cc", - "libc", + "cfg-if", + "home", + "windows-sys 0.48.0", ] [[package]] @@ -553,23 +589,37 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "fastrand" -version = "1.9.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" [[package]] name = "flate2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -593,18 +643,18 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "futures" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", @@ -617,9 +667,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", "futures-sink", @@ -627,15 +677,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" dependencies = [ "futures-core", "futures-task", @@ -644,49 +694,49 @@ dependencies = [ [[package]] name = "futures-intrusive" -version = "0.4.2" +version = "0.5.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a604f7a68fbf8103337523b1fadc8ade7361ee3f112f7c680ad179651616aed5" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.11.2", + "parking_lot", ] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-channel", "futures-core", @@ -712,20 +762,26 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.10" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", ] +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + [[package]] name = "h2" -version = "0.3.20" +version = "0.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" +checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" dependencies = [ "bytes", "fnv", @@ -742,27 +798,21 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - -[[package]] -name = "hashbrown" -version = "0.14.0" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" dependencies = [ - "ahash 0.8.3", + "ahash", "allocator-api2", ] [[package]] name = "hashlink" -version = "0.8.3" +version = "0.8.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.14.0", + "hashbrown", ] [[package]] @@ -776,18 +826,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.2" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" [[package]] name = "hex" @@ -797,9 +838,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hkdf" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "791a029f6b9fc27657f6f188ec6e5e43f6911f6f878e0dc5501396e09809d437" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" dependencies = [ "hmac", ] @@ -813,11 +854,20 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" dependencies = [ "bytes", "fnv", @@ -826,9 +876,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -843,15 +893,15 @@ checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "0.14.27" +version = "0.14.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" dependencies = [ "bytes", "futures-channel", @@ -886,16 +936,16 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.57" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -915,9 +965,9 @@ checksum = 
"b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -925,19 +975,19 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ - "autocfg", - "hashbrown 0.12.3", + "equivalent", + "hashbrown", ] [[package]] name = "indicatif" -version = "0.17.6" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b297dc40733f23a0e52728a58fa9489a5b7638a324932de16b41adc3ef80730" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" dependencies = [ "console", "instant", @@ -954,13 +1004,13 @@ checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" [[package]] name = "inherent" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" +checksum = "0122b7114117e64a63ac49f752a5ca4624d534c7b1c7de796ac196381cd2d947" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -988,32 +1038,21 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi 0.3.2", - "libc", - "windows-sys 0.48.0", -] - [[package]] name = "ipnet" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ - "hermit-abi 0.3.2", - "rustix 0.38.3", - "windows-sys 0.48.0", + "hermit-abi", + "rustix", + "windows-sys 0.52.0", ] [[package]] @@ -1025,17 +1064,26 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.64" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +checksum = 
"406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -1045,12 +1093,15 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin 0.5.2", +] [[package]] name = "libc" -version = "0.2.146" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" @@ -1063,28 +1114,39 @@ dependencies = [ ] [[package]] -name = "linked-hash-map" -version = "0.5.6" +name = "libm" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] -name = "linux-raw-sys" -version = "0.3.8" +name = "libsqlite3-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linked-hash-map" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -1092,9 +1154,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "lopdf" @@ -1111,16 +1173,17 @@ dependencies = [ "md5", "nom", "rayon", - "time 0.3.22", + "time", "weezl", ] [[package]] name = "md-5" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ + "cfg-if", "digest", ] @@ -1132,9 +1195,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "memchr" -version = "2.5.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "memoffset" @@ -1145,15 +1208,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "memoffset" -version = "0.9.0" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - [[package]] name = "mime" version = "0.3.17" @@ -1168,22 +1222,22 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "log", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys 0.48.0", ] @@ -1257,11 +1311,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.4" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.2", "cfg-if", "libc", ] @@ -1286,22 +1340,67 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "smallvec", + "zeroize", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", + "libm", ] [[package]] name = "num_cpus" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.2.6", + "hermit-abi", "libc", ] @@ -1311,19 +1410,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "object" +version = "0.32.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openssl" -version = "0.10.55" +version = "0.10.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +checksum = "15c9d69dd87a29568d4d017cfe8ec518706046a05184e5aea92d0af890b803c8" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.2", "cfg-if", "foreign-types", "libc", @@ -1340,7 +1448,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -1351,18 +1459,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.26.0+1.1.1u" +version = "300.2.2+3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efc62c9f12b22b8f5208c23a7200a442b2e5999f8bdf80233852122b5a4f6f37" +checksum = "8bbfad0063610ac26ee79f7484739e2b07555a75c42453b89263830b5c8103bc" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.90" +version = "0.9.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +checksum = "22e1bf214306098e4832460f797824c05d25aacdf896f64a985fb0fd992454ae" dependencies = [ "cc", "libc", @@ -1377,17 +1485,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -1395,51 +1492,46 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", + "parking_lot_core", ] [[package]] name = "parking_lot_core" -version = "0.8.6" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", - "instant", "libc", - "redox_syscall 0.2.16", + "redox_syscall", "smallvec", - "winapi", + "windows-targets 0.48.5", ] [[package]] -name = "parking_lot_core" -version = "0.9.8" +name = "paste" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.3.5", - "smallvec", - "windows-targets 0.48.0", -] +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] -name = "paste" -version = "1.0.12" +name = "pem-rfc7468" +version = "0.7.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "0.10.0" +version = "1.0.0" dependencies = [ "anyhow", "async-trait", @@ -1451,11 +1543,11 @@ dependencies = [ "indicatif", "inquire", "is-terminal", - "itertools", + "itertools 0.10.5", "lopdf", "md5", "neon", - "parking_lot 0.12.1", + "parking_lot", "pyo3", "pyo3-asyncio", "regex", @@ -1475,9 +1567,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -1485,17 +1577,44 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "portable-atomic" -version = "1.4.2" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f32154ba0af3a075eefa1eda8bb414ee928f62303a54ea85b8d6638ff1a6ee9e" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -1505,9 +1624,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.64" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] @@ -1522,8 +1641,8 @@ dependencies = [ "cfg-if", "indoc", "libc", - "memoffset 0.8.0", - "parking_lot 0.12.1", + "memoffset", + "parking_lot", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -1600,9 +1719,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.29" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -1639,9 +1758,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" dependencies = [ "either", "rayon-core", @@ -1649,9 +1768,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -1659,38 +1778,30 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags 1.3.2", ] [[package]] -name = "redox_users" -version = "0.4.3" +name = "regex" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ - "getrandom", - "redox_syscall 0.2.16", - "thiserror", + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", ] [[package]] -name = "regex" -version = "1.8.4" +name = "regex-automata" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", @@ -1699,17 +1810,17 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64 0.21.2", + "base64", "bytes", "encoding_rs", "futures-core", @@ -1727,9 +1838,12 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", + "sync_wrapper", + "system-configuration", "tokio", "tokio-native-tls", "tower-service", @@ -1742,17 +1856,36 @@ dependencies = [ [[package]] name = "ring" -version = "0.16.20" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", + "getrandom", "libc", - "once_cell", - "spin", + "spin 0.9.8", "untrusted", - "web-sys", - "winapi", + "windows-sys 0.48.0", +] + +[[package]] +name = "rsa" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e5124fcb30e76a7e79bfee683a2746db83784b86289f6251b54b7950a0dfc" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core", + "signature", + "spki", + "subtle", + "zeroize", ] [[package]] @@ -1770,7 +1903,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -1781,58 +1914,59 @@ dependencies = [ ] [[package]] -name = "rustix" -version = "0.37.26" +name = "rustc-demangle" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84f3f8f960ed3b5a59055428714943298bf3fa2d4a1d53135084e0544829d995" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", -] +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.3" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac5ffa1efe7548069688cd7028f32591853cd7b5b756d41bcffd2353e4fc75b4" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.4.2", "errno", "libc", - "linux-raw-sys 0.4.11", - "windows-sys 0.48.0", + "linux-raw-sys", + "windows-sys 0.52.0", ] [[package]] name = "rustls" -version = "0.20.9" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ - "log", "ring", + "rustls-webpki", "sct", - "webpki", ] [[package]] name = "rustls-pemfile" -version = "1.0.2" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64 0.21.2", + "base64", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", ] [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -1845,24 +1979,24 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sct" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ "ring", "untrusted", @@ -1870,14 +2004,15 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.29.1" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "332375aa0c555318544beec038b285c75f2dbeecaecb844383419ccf2663868e" +checksum = "4166a1e072292d46dc91f31617c2a1cdaf55a8be4b5c9f4bf2ba248e3ac4999b" dependencies = [ "inherent", "sea-query-attr", "sea-query-derive", "serde_json", + "uuid", ] [[package]] @@ -1894,33 +2029,34 @@ dependencies = [ [[package]] name = "sea-query-binder" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420eb97201b8a5c76351af7b4925ce5571c2ec3827063a0fb8285d239e1621a0" +checksum = "36bbb68df92e820e4d5aeb17b4acd5cc8b5d18b2c36a4dd6f4626aabfa7ab1b9" dependencies = [ "sea-query", "serde_json", "sqlx", + "uuid", ] [[package]] name = "sea-query-derive" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd78f2e0ee8e537e9195d1049b752e0433e2cac125426bccb7b5c3e508096117" +checksum = "25a82fcb49253abcb45cdcb2adf92956060ec0928635eb21b4f7a6d8f25ab0bc" dependencies = [ "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.48", "thiserror", ] [[package]] name = "security-framework" -version = "2.9.1" +version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ "bitflags 1.3.2", "core-foundation", @@ -1931,9 +2067,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" dependencies = [ "core-foundation-sys", "libc", @@ -1956,29 +2092,29 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.181" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d3e73c93c3240c0bda063c239298e633114c69a888c3e37ca8bb33f343e9890" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.181" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be02f6cb0cd3a5ec20bbcfbcbd749f57daddb1a0882dc2e46a6c236c90b977ed" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum 
= "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" dependencies = [ "itoa", "ryu", @@ -1999,9 +2135,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -2010,9 +2146,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.6" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -2021,9 +2157,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.4" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" dependencies = [ "lazy_static", ] @@ -2058,29 +2194,39 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "slab" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ "autocfg", ] [[package]] name = "smallvec" -version = "1.10.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "socket2" -version = "0.4.9" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -2089,119 +2235,251 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sqlformat" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c12bc9199d1db8234678b7051747c07f517cdcf019262d1847b94ec8b1aee3e" +checksum = "ce81b7bd7c4493975347ef60d8c7e8b742d4694f4c49f93e0a12ea263938176c" dependencies = [ - "itertools", + "itertools 0.12.1", "nom", "unicode_categories", ] [[package]] name = "sqlx" -version = 
"0.6.3" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8de3b03a925878ed54a954f621e64bf55a3c1bd29652d0d1a17830405350188" +checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf" dependencies = [ "sqlx-core", "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", ] [[package]] name = "sqlx-core" -version = "0.6.3" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8241483a83a3f33aa5fff7e7d9def398ff9990b2752b6c6112b83c6d246029" +checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd" dependencies = [ - "ahash 0.7.6", + "ahash", "atoi", - "base64 0.13.1", - "bitflags 1.3.2", "byteorder", "bytes", "crc", "crossbeam-queue", - "dirs", "dotenvy", "either", "event-listener", "futures-channel", "futures-core", "futures-intrusive", + "futures-io", "futures-util", "hashlink", "hex", - "hkdf", - "hmac", "indexmap", - "itoa", - "libc", "log", - "md-5", "memchr", "once_cell", "paste", "percent-encoding", - "rand", "rustls", "rustls-pemfile", "serde", "serde_json", - "sha1", "sha2", "smallvec", "sqlformat", - "sqlx-rt", - "stringprep", "thiserror", - "time 0.3.22", + "time", + "tokio", "tokio-stream", + "tracing", "url", "uuid", "webpki-roots", - "whoami", ] [[package]] name = "sqlx-macros" -version = "0.6.3" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 1.0.109", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9966e64ae989e7e575b19d7265cb79d7fc3cbbdf179835cb0d716f294c2049c9" +checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841" dependencies = [ + "atomic-write-file", "dotenvy", "either", "heck", + "hex", "once_cell", "proc-macro2", "quote", + "serde", "serde_json", "sha2", "sqlx-core", - "sqlx-rt", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", "syn 1.0.109", + "tempfile", + "tokio", "url", ] [[package]] -name = "sqlx-rt" -version = "0.6.3" +name = "sqlx-mysql" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804d3f245f894e61b1e6263c84b23ca675d96753b5abfd5cc8597d86806e8024" +checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "bytes", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", "once_cell", - "tokio", - "tokio-rustls", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" +dependencies = [ + "atoi", + "base64", + "bitflags 2.4.2", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha1", + 
"sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "time", + "tracing", + "url", + "urlencoding", + "uuid", ] [[package]] name = "stringprep" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" +checksum = "bb41d74e231a107a1b4ee36bd1214b11285b77768d2e3824aedafa988fd36ee6" dependencies = [ + "finl_unicode", "unicode-bidi", "unicode-normalization", ] @@ -2231,9 +2509,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.28" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -2242,53 +2520,78 @@ dependencies = [ [[package]] name = "syn-mid" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baa8e7560a164edb1621a55d18a0c59abf49d360f47aa7b821061dd7eea7fac9" +checksum = "fea305d57546cc8cd04feb14b62ec84bf17f50e3f7b12560d7bfa9265f39d9ed" dependencies = [ "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "target-lexicon" -version = "0.12.7" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "tempfile" -version = "3.6.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ - "autocfg", "cfg-if", "fastrand", - "redox_syscall 0.3.5", - "rustix 0.37.26", - "windows-sys 0.48.0", + "rustix", + "windows-sys 0.52.0", ] [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" 
dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -2303,22 +2606,14 @@ dependencies = [ [[package]] name = "time" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" -dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", -] - -[[package]] -name = "time" -version = "0.3.22" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ + "deranged", "itoa", + "num-conv", + "powerfmt", "serde", "time-core", "time-macros", @@ -2326,16 +2621,17 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.9" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -2356,11 +2652,11 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.2" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", "mio", @@ -2373,13 +2669,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] @@ -2392,17 +2688,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls", - "tokio", - "webpki", -] - [[package]] name = "tokio-stream" version = "0.1.14" @@ -2416,9 +2701,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.8" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" dependencies = [ "bytes", "futures-core", @@ -2436,11 +2721,11 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = 
"tracing" -version = "0.1.37" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2448,20 +2733,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", ] [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -2469,12 +2754,12 @@ dependencies = [ [[package]] name = "tracing-log" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" dependencies = [ - "lazy_static", "log", + "once_cell", "tracing-core", ] @@ -2490,9 +2775,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "nu-ansi-term", "serde", @@ -2507,27 +2792,27 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -2540,15 +2825,15 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = 
"unicode-width" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" [[package]] name = "unicode_categories" @@ -2564,21 +2849,27 @@ checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" [[package]] name = "untrusted" -version = "0.7.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8parse" version = "0.2.1" @@ -2587,9 +2878,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.3.4" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ "getrandom", "serde", @@ -2632,12 +2923,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2646,9 +2931,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2656,24 +2941,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.37" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97" dependencies = [ "cfg-if", "js-sys", @@ -2683,9 +2968,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2693,67 +2978,50 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn 2.0.48", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" dependencies = [ "js-sys", "wasm-bindgen", ] -[[package]] -name = "webpki" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0e74f82d49d545ad128049b7e88f6576df2da6b02e9ce565c6f533be576957e" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "webpki-roots" -version = "0.22.6" +version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87" -dependencies = [ - "webpki", -] +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "weezl" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "whoami" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c70234412ca409cc04e864e89523cb0fc37f5e1344ebed5a3ebf4192b6b9f68" -dependencies = [ - "wasm-bindgen", - "web-sys", -] +checksum = "22fc3756b8a9133049b26c7f61ab35416c130e8c09b660f5b3958b446f52cc50" [[package]] name = "winapi" @@ -2787,151 +3055,178 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" +name = "windows-core" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-sys" -version = "0.45.0" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.42.2", + "windows-targets 0.48.5", ] [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.52.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.52.0", ] [[package]] name = "windows-targets" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" [[package]] name = "windows_aarch64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" [[package]] name = "windows_i686_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" [[package]] name = "windows_i686_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" [[package]] name = "windows_x86_64_gnu" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" [[package]] name = "windows_x86_64_msvc" -version = "0.42.2" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index cc126e8cf..633c9d30d 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "0.10.1" +version = "1.0.0" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" @@ -15,10 +15,10 @@ crate-type = ["lib", "cdylib"] [dependencies] rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} -sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } +sqlx = { version = "0.7.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } serde_json = "1.0.9" anyhow = "1.0.9" -tokio = { version = "1.28.2", features = [ "macros" ] } +tokio = { version = "1.28.2", features = [ "macros", "rt-multi-thread" ] } chrono = "0.4.9" pyo3 = { version = "0.18.3", optional = true, features = ["extension-module", "anyhow"] } pyo3-asyncio = { version = "0.18", features = ["attributes", "tokio-runtime"], optional = true } @@ -26,8 +26,8 @@ neon = { version = "0.10", optional = true, default-features = false, features = itertools = "0.10.5" uuid = {version = "1.3.3", features = ["v4", "serde"] } md5 = "0.7.0" -sea-query = { version = "0.29.1", features = ["attr", "thread-safe", "with-json", "postgres-array"] } -sea-query-binder = { version = "0.4.0", features = ["sqlx-postgres", "with-json", "postgres-array"] } +sea-query = { version = "0.30.7", features = ["attr", "thread-safe", "with-json", "with-uuid", "postgres-array"] } +sea-query-binder = { version = "0.5.0", features = ["sqlx-postgres", "with-json", "with-uuid", "postgres-array"] } regex = "1.8.4" reqwest = { version = "0.11", features = ["json", "native-tls-vendored"] } async-trait = "0.1.71" diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index f017a04db..7c989b3a4 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -4,6 +4,7 @@ use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def SingleFieldPipeline(name: str, model: Optional[Model] = None, splitter: Optional[Splitter] = None, parameters: Optional[Json] = Any) -> Pipeline async def migrate() -> None Json = Any @@ -14,6 +15,7 @@ GeneralJsonAsyncIterator = Any const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" export function init_logger(level?: string, format?: string): void; +export function newSingleFieldPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; export function migrate(): Promise; export type Json = any; @@ -25,7 +27,7 @@ export function newCollection(name: string, database_url?: string): Collection; export function newModel(name?: string, source?: string, parameters?: Json): Model; export function newSplitter(name?: string, parameters?: Json): Splitter; export function newBuiltins(database_url?: string): Builtins; -export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; +export function newPipeline(name: string, schema?: Json): Pipeline; export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline; export function newOpenSourceAI(database_url?: string): OpenSourceAI; "#; @@ -37,7 +39,6 @@ fn main() { remove_file(&path).ok(); let mut file = 
OpenOptions::new() .create(true) - .write(true) .append(true) .open(path) .unwrap(); @@ -51,7 +52,6 @@ fn main() { remove_file(&path).ok(); let mut file = OpenOptions::new() .create(true) - .write(true) .append(true) .open(path) .unwrap(); diff --git a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js index f70bf26b4..0ab69decb 100644 --- a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js @@ -1,19 +1,19 @@ const pgml = require("pgml"); require("dotenv").config(); - const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_eqa_collection_2"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_eqa_pipeline_1", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -29,33 +29,31 @@ const main = async () => { ]; await collection.upsert_documents(documents); - const query = "What is the best tool for machine learning?"; - // Perform vector search - const queryResults = await collection - .query() - .vector_recall(query, pipeline) - .limit(1) - .fetch_all(); - - // Construct context from results - const context = queryResults - .map((result) => { - return result[1]; - }) - .join("\n"); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log("The results"); + console.log(queryResults); + + const context = queryResults.map((result) => result["chunk"]).join("\n\n"); // Query for answer const builtins = pgml.newBuiltins(); const answer = await builtins.transform("question-answering", [ JSON.stringify({ question: query, context: context }), ]); + console.log("The answer"); + console.log(answer); // Archive the collection await collection.archive(); - return answer; }; -main().then((results) => { - console.log("Question answer: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/question_answering.js b/pgml-sdks/pgml/javascript/examples/question_answering.js index f8f7f83f5..0d4e08844 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering.js @@ -3,16 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_qa_collection"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_qa_pipeline", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split 
into chunks and embedded by our pipeline @@ -29,27 +30,19 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall("What is the best tool for machine learning?", pipeline) - .limit(1) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js index 1e4c22164..bb265cc6a 100644 --- a/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js +++ b/pgml-sdks/pgml/javascript/examples/question_answering_instructor.js @@ -3,18 +3,20 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_qai_collection"); + const collection = pgml.newCollection("qa_pipeline"); // Add a pipeline - const model = pgml.newModel("hkunlp/instructor-base", "pgml", { - instruction: "Represent the Wikipedia document for retrieval: ", + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "hkunlp/instructor-base", + parameters: { + instruction: "Represent the Wikipedia document for retrieval: " + } + }, + }, }); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_qai_pipeline", - model, - splitter, - ); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -31,30 +33,25 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall("What is the best tool for machine learning?", pipeline, { - instruction: - "Represent the Wikipedia question for retrieving supporting documents: ", - }) - .limit(1) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { + query: query, + parameters: { + instruction: + "Represent the Wikipedia question for retrieving supporting documents: ", + } + } + } + }, limit: 1 + }, pipeline); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/semantic_search.js b/pgml-sdks/pgml/javascript/examples/semantic_search.js index b1458e889..a40970768 100644 --- 
a/pgml-sdks/pgml/javascript/examples/semantic_search.js +++ b/pgml-sdks/pgml/javascript/examples/semantic_search.js @@ -3,12 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_collection"); + const collection = pgml.newCollection("semantic_search_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline("my_javascript_pipeline", model, splitter); + const pipeline = pgml.newPipeline("semantic_search_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -25,30 +30,20 @@ const main = async () => { await collection.upsert_documents(documents); // Perform vector search - const queryResults = await collection - .query() - .vector_recall( - "Some user query that will match document one first", - pipeline, - ) - .limit(2) - .fetch_all(); - - // Convert the results to an array of objects - const results = queryResults.map((result) => { - const [similarity, text, metadata] = result; - return { - similarity, - text, - metadata, - }; - }); + const query = "Something that will match document one first"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 2 + }, pipeline); + console.log("The results"); + console.log(queryResults); // Archive the collection await collection.archive(); - return results; }; -main().then((results) => { - console.log("Vector search Results: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js index f779cde60..5afeba45c 100644 --- a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js +++ b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js @@ -3,16 +3,17 @@ require("dotenv").config(); const main = async () => { // Initialize the collection - const collection = pgml.newCollection("my_javascript_sqa_collection"); + const collection = pgml.newCollection("qa_collection"); // Add a pipeline - const model = pgml.newModel(); - const splitter = pgml.newSplitter(); - const pipeline = pgml.newPipeline( - "my_javascript_sqa_pipeline", - model, - splitter, - ); + const pipeline = pgml.newPipeline("qa_pipeline", { + text: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "intfloat/e5-small", + }, + }, + }); await collection.add_pipeline(pipeline); // Upsert documents, these documents are automatically split into chunks and embedded by our pipeline @@ -28,21 +29,20 @@ const main = async () => { ]; await collection.upsert_documents(documents); - const query = "What is the best tool for machine learning?"; - // Perform vector search - const queryResults = await collection - .query() - .vector_recall(query, pipeline) - .limit(1) - .fetch_all(); - - // Construct context from results - const context = queryResults - .map((result) => { - return result[1]; - }) - .join("\n"); + const query = "What is the best tool for building machine learning applications?"; + const queryResults = await collection.vector_search( + { + query: { + fields: { + text: { query: query } + } + }, limit: 1 + }, pipeline); + 
console.log("The results"); + console.log(queryResults); + + const context = queryResults.map((result) => result["chunk"]).join("\n\n"); // Query for summarization const builtins = pgml.newBuiltins(); @@ -50,12 +50,11 @@ const main = async () => { { task: "summarization", model: "sshleifer/distilbart-cnn-12-6" }, [context], ); + console.log("The summary"); + console.log(answer); // Archive the collection await collection.archive(); - return answer; }; -main().then((results) => { - console.log("Question summary: \n", results); -}); +main().then(() => console.log("Done!")); diff --git a/pgml-sdks/pgml/javascript/package-lock.json b/pgml-sdks/pgml/javascript/package-lock.json index 9ab5f611e..e3035d038 100644 --- a/pgml-sdks/pgml/javascript/package-lock.json +++ b/pgml-sdks/pgml/javascript/package-lock.json @@ -1,13 +1,16 @@ { "name": "pgml", - "version": "0.9.6", + "version": "1.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "pgml", - "version": "0.9.6", + "version": "1.0.0", "license": "MIT", + "dependencies": { + "dotenv": "^16.4.4" + }, "devDependencies": { "@types/node": "^20.3.1", "cargo-cp-artifact": "^0.1" @@ -27,6 +30,17 @@ "bin": { "cargo-cp-artifact": "bin/cargo-cp-artifact.js" } + }, + "node_modules/dotenv": { + "version": "16.4.5", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.4.5.tgz", + "integrity": "sha512-ZmdL2rui+eB2YwhsWzjInR8LldtZHGDoQ1ugH85ppHKwpUHL7j7rN0Ti9NCnGiQbhaZ11FpR+7ao1dNsmduNUg==", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } } } } diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 9b6502458..a6572d67f 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,6 +1,6 @@ { "name": "pgml", - "version": "0.10.1", + "version": "1.0.0", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", "keywords": [ "postgres", @@ -26,5 +26,8 @@ "devDependencies": { "@types/node": "^20.3.1", "cargo-cp-artifact": "^0.1" + }, + "dependencies": { + "dotenv": "^16.4.4" } } diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index ad0c9cd78..9fa4e4954 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -17,6 +17,8 @@ const generate_dummy_documents = (count: number) => { for (let i = 0; i < count; i++) { docs.push({ id: i, + title: `Test Document ${i}`, + body: `Test body ${i}`, text: `This is a test document: ${i}`, project: "a10", uuid: i * 10, @@ -50,9 +52,14 @@ it("can create splitter", () => { }); it("can create pipeline", () => { + let pipeline = pgml.newPipeline("test_j_p_ccp"); + expect(pipeline).toBeTruthy(); +}); + +it("can create single field pipeline", () => { let model = pgml.newModel(); let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_ccc_0", model, splitter); + let pipeline = pgml.newSingleFieldPipeline("test_j_p_ccsfp", model, splitter); expect(pipeline).toBeTruthy(); }); @@ -62,145 +69,97 @@ it("can create builtins", () => { }); /////////////////////////////////////////////////// -// Test various vector searches /////////////////// +// Test various searches /////////////////// /////////////////////////////////////////////////// -it("can vector search with local embeddings", async () => { - let model = pgml.newModel(); - let splitter = 
pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswle_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswle_3"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection.vector_search("Here is some query", pipeline); - expect(results).toHaveLength(3); - await collection.archive(); -}); - -it("can vector search with remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswre_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswre_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection.vector_search("Here is some query", pipeline); - expect(results).toHaveLength(3); +it("can search", async () => { + let pipeline = pgml.newPipeline("test_j_p_cs", { + title: { semantic_search: { model: "intfloat/e5-small" } }, + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", + }, + full_text_search: { configuration: "english" }, + }, + }); + let collection = pgml.newCollection("test_j_c_tsc_15") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.search( + { + query: { + full_text_search: { body: { query: "Test", boost: 1.2 } }, + semantic_search: { + title: { query: "This is a test", boost: 2.0 }, + body: { query: "This is the body test", boost: 1.01 }, + }, + filter: { id: { $gt: 1 } }, + }, + limit: 10 + }, + pipeline, + ); + let ids = results["results"].map((r: any) => r["id"]); + expect(ids).toEqual([5, 4, 3]); await collection.archive(); }); -it("can vector search with query builder", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqb_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqb_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); +/////////////////////////////////////////////////// +// Test various vector searches /////////////////// +/////////////////////////////////////////////////// -it("can vector search with query builder with remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbwre_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbwre_1"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); -it("can vector search with query builder and metadata filtering", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbamf_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbamf_4"); - await 
collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - metadata: { - $or: [{ uuid: { $eq: 0 } }, { floating_uuid: { $lt: 2 } }], - project: { $eq: "a10" }, +it("can vector search", async () => { + let pipeline = pgml.newPipeline("test_j_p_cvs_0", { + title: { + semantic_search: { model: "intfloat/e5-small" }, + full_text_search: { configuration: "english" }, + }, + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(2); - await collection.archive(); -}); - -it("can vector search with query builder and custom hnsfw ef_search value", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_cvswqbachesv_0", model, splitter); - let collection = pgml.newCollection("test_j_c_cvswqbachesv_0"); - await collection.upsert_documents(generate_dummy_documents(3)); - await collection.add_pipeline(pipeline); - let results = await collection - .query() - .vector_recall("Here is some query", pipeline) - .filter({ - hnsw: { - ef_search: 2, + }, + }); + let collection = pgml.newCollection("test_j_c_cvs_4") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.vector_search( + { + query: { + fields: { + title: { query: "Test document: 2", full_text_filter: "test" }, + body: { query: "Test document: 2" }, + }, + filter: { id: { "$gt": 2 } }, }, - }) - .limit(10) - .fetch_all(); - expect(results).toHaveLength(3); + limit: 5, + }, + pipeline, + ); + let ids = results.map(r => r["document"]["id"]); + expect(ids).toEqual([3, 4, 4, 3]); await collection.archive(); }); -it("can vector search with query builder and custom hnsfw ef_search value and remote embeddings", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); +it("can vector search with query builder", async () => { + let model = pgml.newModel(); let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline( - "test_j_p_cvswqbachesvare_0", - model, - splitter, - ); - let collection = pgml.newCollection("test_j_c_cvswqbachesvare_0"); + let pipeline = pgml.newSingleFieldPipeline("test_j_p_cvswqb_0", model, splitter); + let collection = pgml.newCollection("test_j_c_cvswqb_2"); await collection.upsert_documents(generate_dummy_documents(3)); await collection.add_pipeline(pipeline); let results = await collection .query() .vector_recall("Here is some query", pipeline) - .filter({ - hnsw: { - ef_search: 2, - }, - }) .limit(10) .fetch_all(); - expect(results).toHaveLength(3); - await collection.archive(); -}); - -/////////////////////////////////////////////////// -// Test user output facing functions ////////////// -/////////////////////////////////////////////////// - -it("pipeline to dict", async () => { - let model = pgml.newModel("text-embedding-ada-002", "openai"); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_j_p_ptd_0", model, splitter); - let collection = pgml.newCollection("test_j_c_ptd_2"); - await collection.add_pipeline(pipeline); - let pipeline_dict = await pipeline.to_dict(); - expect(pipeline_dict["name"]).toBe("test_j_p_ptd_0"); + let ids = results.map(r => r[2]["id"]); + expect(ids).toEqual([2, 1, 0]); await collection.archive(); }); @@ 
-209,60 +168,38 @@ it("pipeline to dict", async () => { /////////////////////////////////////////////////// it("can upsert and get documents", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline("test_p_p_cuagd_0", model, splitter, { - full_text_search: { active: true, configuration: "english" }, - }); let collection = pgml.newCollection("test_p_c_cuagd_1"); - await collection.add_pipeline(pipeline); await collection.upsert_documents(generate_dummy_documents(10)); - let documents = await collection.get_documents(); expect(documents).toHaveLength(10); - documents = await collection.get_documents({ offset: 1, limit: 2, - filter: { metadata: { id: { $gt: 0 } } }, + filter: { id: { $gt: 0 } }, }); expect(documents).toHaveLength(2); expect(documents[0]["document"]["id"]).toBe(2); let last_row_id = documents[1]["row_id"]; - documents = await collection.get_documents({ filter: { - metadata: { id: { $gt: 3 } }, - full_text_search: { configuration: "english", text: "4" }, + id: { $lt: 7 }, }, last_row_id: last_row_id, }); - expect(documents).toHaveLength(1); + expect(documents).toHaveLength(3); expect(documents[0]["document"]["id"]).toBe(4); - await collection.archive(); }); it("can delete documents", async () => { - let model = pgml.newModel(); - let splitter = pgml.newSplitter(); - let pipeline = pgml.newPipeline( - "test_p_p_cdd_0", - model, - splitter, - - { full_text_search: { active: true, configuration: "english" } }, - ); let collection = pgml.newCollection("test_p_c_cdd_2"); - await collection.add_pipeline(pipeline); await collection.upsert_documents(generate_dummy_documents(3)); await collection.delete_documents({ - metadata: { id: { $gte: 0 } }, - full_text_search: { configuration: "english", text: "0" }, + id: { $gte: 2 }, }); let documents = await collection.get_documents(); expect(documents).toHaveLength(2); - expect(documents[0]["document"]["id"]).toBe(1); + expect(documents[0]["document"]["id"]).toBe(0); await collection.archive(); }); @@ -286,13 +223,13 @@ it("can order documents", async () => { it("can transformer pipeline", async () => { const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform(["AI is going to"], {max_new_tokens: 5}); + const it = await t.transform(["AI is going to"], { max_new_tokens: 5 }); expect(it.length).toBeGreaterThan(0) }); it("can transformer pipeline stream", async () => { const t = pgml.newTransformerPipeline("text-generation"); - const it = await t.transform_stream("AI is going to", {max_new_tokens: 5}); + const it = await t.transform_stream("AI is going to", { max_new_tokens: 5 }); let result = await it.next(); let output = []; while (!result.done) { @@ -309,17 +246,17 @@ it("can transformer pipeline stream", async () => { it("can open source ai create", () => { const client = pgml.newOpenSourceAI(); const results = client.chat_completions_create( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -328,17 +265,17 @@ it("can open source ai create", () => { 
it("can open source ai create async", async () => { const client = pgml.newOpenSourceAI(); const results = await client.chat_completions_create_async( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); expect(results.choices.length).toBeGreaterThan(0); }); @@ -347,17 +284,17 @@ it("can open source ai create async", async () => { it("can open source ai create stream", () => { const client = pgml.newOpenSourceAI(); const it = client.chat_completions_create_stream( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); let result = it.next(); while (!result.done) { @@ -369,17 +306,17 @@ it("can open source ai create stream", () => { it("can open source ai create stream async", async () => { const client = pgml.newOpenSourceAI(); const it = await client.chat_completions_create_stream_async( - "HuggingFaceH4/zephyr-7b-beta", - [ - { - role: "system", - content: "You are a friendly chatbot who always responds in the style of a pirate", - }, - { - role: "user", - content: "How many helicopters can a human eat in one sitting?", - }, - ], + "HuggingFaceH4/zephyr-7b-beta", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], ); let result = await it.next(); while (!result.done) { diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index c7b5b4c08..7c3e14230 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "0.10.1" +version = "1.0.0" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." 
authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/python/examples/extractive_question_answering.py b/pgml-sdks/pgml/python/examples/extractive_question_answering.py index 21b5f2e67..21a0060f5 100644 --- a/pgml-sdks/pgml/python/examples/extractive_question_answering.py +++ b/pgml-sdks/pgml/python/examples/extractive_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins +from pgml import Collection, Pipeline, Builtins import json from datasets import load_dataset from time import time @@ -14,10 +14,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -36,8 +42,8 @@ async def main(): query = "Who won more than 20 grammy awards?" console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 10}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") @@ -45,8 +51,8 @@ async def main(): console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for answer builtins = Builtins() diff --git a/pgml-sdks/pgml/python/examples/question_answering.py b/pgml-sdks/pgml/python/examples/question_answering.py index 923eebc31..d4b2cc082 100644 --- a/pgml-sdks/pgml/python/examples/question_answering.py +++ b/pgml-sdks/pgml/python/examples/question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -13,10 +13,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -31,12 +37,12 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query - query = "Who won 20 grammy awards?" - console.print("Querying for %s..." % query) + # Query for answer + query = "Who won more than 20 grammy awards?" 
+ console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/examples/question_answering_instructor.py b/pgml-sdks/pgml/python/examples/question_answering_instructor.py index 3ca71e429..ba0069837 100644 --- a/pgml-sdks/pgml/python/examples/question_answering_instructor.py +++ b/pgml-sdks/pgml/python/examples/question_answering_instructor.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -11,15 +11,23 @@ async def main(): console = Console() # Initialize collection - collection = Collection("squad_collection_1") + collection = Collection("squad_collection") - # Create a pipeline using hkunlp/instructor-base - model = Model( - name="hkunlp/instructor-base", - parameters={"instruction": "Represent the Wikipedia document for retrieval: "}, + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + }, + } + }, ) - splitter = Splitter() - pipeline = Pipeline("squad_instruction", model, splitter) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -34,21 +42,25 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query + # Query for answer query = "Who won more than 20 grammy awards?" - console.print("Querying for %s..." 
% query) + console.print("Querying for context ...") start = time() - results = ( - await collection.query() - .vector_recall( - query, - pipeline, - query_parameters={ - "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + results = await collection.vector_search( + { + "query": { + "fields": { + "text": { + "query": query, + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + }, + } }, - ) - .limit(5) - .fetch_all() + "limit": 5, + }, + pipeline, ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/examples/rag_question_answering.py b/pgml-sdks/pgml/python/examples/rag_question_answering.py index 94db6846c..2558287f6 100644 --- a/pgml-sdks/pgml/python/examples/rag_question_answering.py +++ b/pgml-sdks/pgml/python/examples/rag_question_answering.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline, Builtins, OpenSourceAI +from pgml import Collection, Pipeline, OpenSourceAI, init_logger import json from datasets import load_dataset from time import time @@ -7,6 +7,9 @@ import asyncio +init_logger() + + async def main(): load_dotenv() console = Console() @@ -14,10 +17,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -34,22 +43,19 @@ async def main(): # Query for context query = "Who won more than 20 grammy awards?" - - console.print("Question: %s"%query) console.print("Querying for context ...") - start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 10}, pipeline ) end = time() - - #console.print("Query time = %0.3f" % (end - start)) + console.print("\n Results for '%s' " % (query), style="bold") + console.print(results) + console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") - console.print("Context is ready...") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for answer system_prompt = """Use the following pieces of context to answer the question at the end. 
diff --git a/pgml-sdks/pgml/python/examples/semantic_search.py b/pgml-sdks/pgml/python/examples/semantic_search.py index df861502f..9a4e134e5 100644 --- a/pgml-sdks/pgml/python/examples/semantic_search.py +++ b/pgml-sdks/pgml/python/examples/semantic_search.py @@ -1,4 +1,4 @@ -from pgml import Collection, Model, Splitter, Pipeline +from pgml import Collection, Pipeline from datasets import load_dataset from time import time from dotenv import load_dotenv @@ -13,17 +13,24 @@ async def main(): # Initialize collection collection = Collection("quora_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("quorav1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "quorav1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) - + # Prep documents for upserting dataset = load_dataset("quora", split="train") questions = [] for record in dataset["questions"]: questions.extend(record["text"]) + # Remove duplicates and add id documents = [] for i, question in enumerate(list(set(questions))): @@ -31,14 +38,14 @@ async def main(): documents.append({"id": i, "text": question}) # Upsert documents - await collection.upsert_documents(documents[:200]) + await collection.upsert_documents(documents[:2000]) # Query query = "What is a good mobile os?" console.print("Querying for %s..." % query) start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py index 3008b31a9..862830277 100644 --- a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py +++ b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py @@ -14,10 +14,16 @@ async def main(): # Initialize collection collection = Collection("squad_collection") - # Create a pipeline using the default model and splitter - model = Model() - splitter = Splitter() - pipeline = Pipeline("squadv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "squadv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -32,12 +38,12 @@ async def main(): # Upsert documents await collection.upsert_documents(documents[:200]) - # Query for context + # Query for answer query = "Who won more than 20 grammy awards?" 
console.print("Querying for context ...") start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 3}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") @@ -45,8 +51,8 @@ async def main(): console.print("Query time = %0.3f" % (end - start)) # Construct context from results - context = " ".join(results[0][1].strip().split()) - context = context.replace('"', '\\"').replace("'", "''") + chunks = [r["chunk"] for r in results] + context = "\n\n".join(chunks) # Query for summary builtins = Builtins() diff --git a/pgml-sdks/pgml/python/examples/table_question_answering.py b/pgml-sdks/pgml/python/examples/table_question_answering.py index 168a830b2..243380647 100644 --- a/pgml-sdks/pgml/python/examples/table_question_answering.py +++ b/pgml-sdks/pgml/python/examples/table_question_answering.py @@ -15,11 +15,17 @@ async def main(): # Initialize collection collection = Collection("ott_qa_20k_collection") - # Create a pipeline using deepset/all-mpnet-base-v2-table - # A SentenceTransformer model trained specifically for embedding tabular data for retrieval - model = Model(name="deepset/all-mpnet-base-v2-table") - splitter = Splitter() - pipeline = Pipeline("ott_qa_20kv1", model, splitter) + # Create and add pipeline + pipeline = Pipeline( + "ott_qa_20kv1", + { + "text": { + "splitter": {"model": "recursive_character"}, + # A SentenceTransformer model trained specifically for embedding tabular data for retrieval + "semantic_search": {"model": "deepset/all-mpnet-base-v2-table"}, + } + }, + ) await collection.add_pipeline(pipeline) # Prep documents for upserting @@ -46,8 +52,8 @@ async def main(): query = "Which country has the highest GDP in 2020?" console.print("Querying for %s..." 
% query) start = time() - results = ( - await collection.query().vector_recall(query, pipeline).limit(5).fetch_all() + results = await collection.vector_search( + {"query": {"fields": {"text": {"query": query}}}, "limit": 5}, pipeline ) end = time() console.print("\n Results for '%s' " % (query), style="bold") diff --git a/pgml-sdks/pgml/python/tests/stress_test.py b/pgml-sdks/pgml/python/tests/stress_test.py new file mode 100644 index 000000000..552193690 --- /dev/null +++ b/pgml-sdks/pgml/python/tests/stress_test.py @@ -0,0 +1,110 @@ +import asyncio +import pgml +import time +from datasets import load_dataset + +pgml.init_logger() + +TOTAL_ROWS = 10000 +BATCH_SIZE = 1000 +OFFSET = 0 + +dataset = load_dataset( + "wikipedia", "20220301.en", trust_remote_code=True, split="train" +) + +collection = pgml.Collection("stress-test-collection-3") +pipeline = pgml.Pipeline( + "stress-test-pipeline-1", + { + "text": { + "splitter": { + "model": "recursive_character", + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval: " + }, + }, + }, + }, +) + + +async def upsert_data(): + print(f"\n\nUploading {TOTAL_ROWS} in batches of {BATCH_SIZE}") + total = 0 + batch = [] + tic = time.perf_counter() + for d in dataset: + total += 1 + if total < OFFSET: + continue + batch.append(d) + if len(batch) >= BATCH_SIZE or total >= TOTAL_ROWS: + await collection.upsert_documents(batch, {"batch_size": 1000}) + batch = [] + if total >= TOTAL_ROWS: + break + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def test_document_search(): + print("\n\nDoing document search") + tic = time.perf_counter() + + results = await collection.search( + { + "query": { + "semantic_search": { + "text": { + "query": "What is the best fruit?", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + } + }, + "filter": {"title": {"$ne": "filler"}}, + }, + "limit": 1, + }, + pipeline, + ) + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def test_vector_search(): + print("\n\nDoing vector search") + tic = time.perf_counter() + results = await collection.vector_search( + { + "query": { + "fields": { + "text": { + "query": "What is the best fruit?", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: " + }, + }, + }, + "filter": {"title": {"$ne": "filler"}}, + }, + "limit": 5, + }, + pipeline, + ) + toc = time.perf_counter() + print(f"Done in {toc - tic:0.4f} seconds\n\n") + + +async def main(): + await collection.add_pipeline(pipeline) + await upsert_data() + await test_document_search() + await test_vector_search() + + +asyncio.run(main()) diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 748367867..e4186d4d3 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -14,11 +14,6 @@ #################################################################################### #################################################################################### -DATABASE_URL = os.environ.get("DATABASE_URL") -if DATABASE_URL is None: - print("No DATABASE_URL environment variable found. 
Please set one") - exit(1) - pgml.init_logger() @@ -28,6 +23,8 @@ def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: dummy_documents.append( { "id": i, + "title": "Test Document {}".format(i), + "body": "Test body {}".format(i), "text": "This is a test document: {}".format(i), "project": "a10", "floating_uuid": i * 1.01, @@ -60,9 +57,14 @@ def test_can_create_splitter(): def test_can_create_pipeline(): + pipeline = pgml.Pipeline("test_p_p_tccp_0", {}) + assert pipeline is not None + + +def test_can_create_single_field_pipeline(): model = pgml.Model() splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tccp_0", model, splitter) + pipeline = pgml.SingleFieldPipeline("test_p_p_tccsfp_0", model, splitter, {}) assert pipeline is not None @@ -72,151 +74,105 @@ def test_can_create_builtins(): ################################################### -## Test various vector searches ################### +## Test searches ################################## ################################################### @pytest.mark.asyncio -async def test_can_vector_search_with_local_embeddings(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvs_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvs_4") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = await collection.vector_search("Here is some query", pipeline) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswre_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswre_3") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = await collection.vector_search("Here is some query", pipeline) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqb_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqb_5") - await collection.upsert_documents(generate_dummy_documents(3)) +async def test_can_search(): + pipeline = pgml.Pipeline( + "test_p_p_tcs_0", + { + "title": {"semantic_search": {"model": "intfloat/e5-small"}}, + "body": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + "full_text_search": {"configuration": "english"}, + }, + }, + ) + collection = pgml.Collection("test_p_c_tsc_13") await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all() + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.search( + { + "query": { + "full_text_search": {"body": {"query": "Test", "boost": 1.2}}, + "semantic_search": { + "title": {"query": "This is a test", "boost": 2.0}, + "body": {"query": "This is the body test", "boost": 1.01}, + }, + "filter": {"id": {"$gt": 1}}, + }, + "limit": 5, + }, + pipeline, ) - assert len(results) == 3 + ids = [result["id"] for result in results["results"]] + assert ids == [5, 4, 3] await collection.archive() -@pytest.mark.asyncio -async def 
test_can_vector_search_with_query_builder_with_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbwre_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbwre_1") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .limit(10) - .fetch_all() - ) - assert len(results) == 3 - await collection.archive() +################################################### +## Test various vector searches ################### +################################################### @pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_metadata_filtering(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbamf_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbamf_2") - await collection.upsert_documents(generate_dummy_documents(3)) +async def test_can_vector_search(): + pipeline = pgml.Pipeline( + "test_p_p_tcvs_0", + { + "title": { + "semantic_search": {"model": "intfloat/e5-small"}, + "full_text_search": {"configuration": "english"}, + }, + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": {"model": "intfloat/e5-small"}, + }, + }, + ) + collection = pgml.Collection("test_p_c_tcvs_3") await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .filter( - { - "metadata": { - "$or": [{"uuid": {"$eq": 0}}, {"floating_uuid": {"$lt": 2}}], - "project": {"$eq": "a10"}, + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.vector_search( + { + "query": { + "fields": { + "title": {"query": "Test document: 2", "full_text_filter": "test"}, + "text": {"query": "Test document: 2"}, }, - } - ) - .limit(10) - .fetch_all() + "filter": {"id": {"$gt": 2}}, + }, + "limit": 5, + }, + pipeline, ) - assert len(results) == 2 + ids = [result["document"]["id"] for result in results] + assert ids == [3, 3, 4, 4] await collection.archive() @pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value(): +async def test_can_vector_search_with_query_builder(): model = pgml.Model() splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbachesv_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbachesv_0") - await collection.upsert_documents(generate_dummy_documents(3)) - await collection.add_pipeline(pipeline) - results = ( - await collection.query() - .vector_recall("Here is some query", pipeline) - .filter({"hnsw": {"ef_search": 2}}) - .limit(10) - .fetch_all() - ) - assert len(results) == 3 - await collection.archive() - - -@pytest.mark.asyncio -async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tcvswqbachesvare_0", model, splitter) - collection = pgml.Collection(name="test_p_c_tcvswqbachesvare_0") + pipeline = pgml.SingleFieldPipeline("test_p_p_tcvswqb_1", model, splitter) + collection = pgml.Collection(name="test_p_c_tcvswqb_5") await collection.upsert_documents(generate_dummy_documents(3)) await collection.add_pipeline(pipeline) results = ( await collection.query() 
.vector_recall("Here is some query", pipeline) - .filter({"hnsw": {"ef_search": 2}}) .limit(10) .fetch_all() ) - assert len(results) == 3 - await collection.archive() - - -################################################### -## Test user output facing functions ############## -################################################### - - -@pytest.mark.asyncio -async def test_pipeline_to_dict(): - model = pgml.Model(name="text-embedding-ada-002", source="openai") - splitter = pgml.Splitter() - pipeline = pgml.Pipeline("test_p_p_tptd_1", model, splitter) - collection = pgml.Collection(name="test_p_c_tptd_1") - await collection.add_pipeline(pipeline) - pipeline_dict = await pipeline.to_dict() - assert pipeline_dict["name"] == "test_p_p_tptd_1" - await collection.remove_pipeline(pipeline) + ids = [document["id"] for (_, _, document) in results] + assert ids == [2, 1, 0] await collection.archive() @@ -227,64 +183,38 @@ async def test_pipeline_to_dict(): @pytest.mark.asyncio async def test_upsert_and_get_documents(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline( - "test_p_p_tuagd_0", - model, - splitter, - {"full_text_search": {"active": True, "configuration": "english"}}, - ) - collection = pgml.Collection(name="test_p_c_tuagd_2") - await collection.add_pipeline( - pipeline, - ) + collection = pgml.Collection("test_p_c_tuagd_2") await collection.upsert_documents(generate_dummy_documents(10)) - documents = await collection.get_documents() assert len(documents) == 10 - documents = await collection.get_documents( - {"offset": 1, "limit": 2, "filter": {"metadata": {"id": {"$gt": 0}}}} + {"offset": 1, "limit": 2, "filter": {"id": {"$gt": 0}}} ) assert len(documents) == 2 and documents[0]["document"]["id"] == 2 last_row_id = documents[-1]["row_id"] - documents = await collection.get_documents( { "filter": { - "metadata": {"id": {"$gt": 3}}, - "full_text_search": {"configuration": "english", "text": "4"}, + "id": {"$lt": 7}, }, "last_row_id": last_row_id, } ) - assert len(documents) == 1 and documents[0]["document"]["id"] == 4 - + assert len(documents) == 3 and documents[0]["document"]["id"] == 4 await collection.archive() @pytest.mark.asyncio async def test_delete_documents(): - model = pgml.Model() - splitter = pgml.Splitter() - pipeline = pgml.Pipeline( - "test_p_p_tdd_0", - model, - splitter, - {"full_text_search": {"active": True, "configuration": "english"}}, - ) collection = pgml.Collection("test_p_c_tdd_1") - await collection.add_pipeline(pipeline) await collection.upsert_documents(generate_dummy_documents(3)) await collection.delete_documents( { - "metadata": {"id": {"$gte": 0}}, - "full_text_search": {"configuration": "english", "text": "0"}, + "id": {"$gte": 2}, } ) documents = await collection.get_documents() - assert len(documents) == 2 and documents[0]["document"]["id"] == 1 + assert len(documents) == 2 and documents[0]["document"]["id"] == 0 await collection.archive() @@ -457,30 +387,3 @@ async def test_migrate(): # assert len(x) == 3 # # await collection.archive() - - -################################################### -## Manual tests ################################### -################################################### - - -# async def test_add_pipeline(): -# model = pgml.Model() -# splitter = pgml.Splitter() -# pipeline = pgml.Pipeline("silas_test_p_1", model, splitter) -# collection = pgml.Collection(name="silas_test_c_10") -# await collection.add_pipeline(pipeline) -# -# async def test_upsert_documents(): -# collection = 
pgml.Collection(name="silas_test_c_9") -# await collection.upsert_documents(generate_dummy_documents(10)) -# -# async def test_vector_search(): -# pipeline = pgml.Pipeline("silas_test_p_1") -# collection = pgml.Collection(name="silas_test_c_9") -# results = await collection.vector_search("Here is some query", pipeline) -# print(results) - -# asyncio.run(test_add_pipeline()) -# asyncio.run(test_upsert_documents()) -# asyncio.run(test_vector_search()) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index e893e64c5..5d43c6a3d 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -3,26 +3,28 @@ use indicatif::MultiProgress; use itertools::Itertools; use regex::Regex; use rust_bridge::{alias, alias_methods}; -use sea_query::{Alias, Expr, JoinType, NullOrdering, Order, PostgresQueryBuilder, Query}; +use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; use serde_json::json; -use sqlx::postgres::PgPool; use sqlx::Executor; use sqlx::PgConnection; use std::borrow::Cow; +use std::collections::HashMap; use std::path::Path; use std::time::SystemTime; +use std::time::UNIX_EPOCH; use tracing::{instrument, warn}; use walkdir::WalkDir; +use crate::debug_sqlx_query; +use crate::filter_builder::FilterBuilder; +use crate::search_query_builder::build_search_query; +use crate::vector_search_query_builder::build_vector_search_query; use crate::{ - filter_builder, get_or_initialize_pool, - model::ModelRuntime, - models, order_by_builder, + get_or_initialize_pool, models, order_by_builder, pipeline::Pipeline, queries, query_builder, query_builder::QueryBuilder, - remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, IntoTableNameAndSchema, Json, SIden, TryToNumeric}, utils, @@ -104,7 +106,6 @@ pub struct Collection { pub database_url: Option<String>, pub pipelines_table_name: String, pub documents_table_name: String, - pub transforms_table_name: String, pub chunks_table_name: String, pub documents_tsvectors_table_name: String, pub(crate) database_data: Option<CollectionDatabaseData>, @@ -121,12 +122,16 @@ pub struct Collection { remove_pipeline, enable_pipeline, disable_pipeline, + search, + add_search_event, vector_search, query, exists, archive, upsert_directory, - upsert_file + upsert_file, + generate_er_diagram, + get_pipeline_status )] impl Collection { /// Creates a new [Collection] /// /// # Arguments /// /// * `name` - The name of the collection. /// * `database_url` - An optional database_url. If passed, this url will be used instead of the `DATABASE_URL` environment variable. /// /// # Example /// /// ``` /// use pgml::Collection; /// let collection = Collection::new("my_collection", None); /// ``` - pub fn new(name: &str, database_url: Option<String>) -> Self { + pub fn new(name: &str, database_url: Option<String>) -> anyhow::Result<Self> { + if !name + .chars() + .all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_') + { + anyhow::bail!( + "Name must only consist of letters, numbers, white space, and '-' or '_'" + ) + } let ( pipelines_table_name, documents_table_name, - transforms_table_name, chunks_table_name, documents_tsvectors_table_name, ) = Self::generate_table_names(name); - Self { + Ok(Self { name: name.to_string(), database_url, pipelines_table_name, documents_table_name, - transforms_table_name, chunks_table_name, documents_tsvectors_table_name, database_data: None, - } + }) }
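// A minimal sketch (not part of this diff) of the now-fallible constructor;
// the collection names and the `use pgml::Collection;` context are assumed
// for illustration:
let mut collection = Collection::new("my_collection", None)?;
// '/' fails the new character check, so a bad name errors immediately instead
// of failing later inside SQL:
assert!(Collection::new("my/collection", None).is_err());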
#[instrument(skip(self))] @@ -233,16 +244,14 @@ impl Collection { }, }; + // Splitters table is not unique to a collection or pipeline. It exists in the pgml schema Splitter::create_splitters_table(&mut transaction).await?; + self.create_documents_table(&mut transaction).await?; Pipeline::create_pipelines_table( &collection_database_data.project_info, &mut transaction, ) .await?; - self.create_documents_table(&mut transaction).await?; - self.create_chunks_table(&mut transaction).await?; - self.create_documents_tsvectors_table(&mut transaction) - .await?; transaction.commit().await?; Some(collection_database_data) @@ -252,167 +261,105 @@ impl Collection { } /// Adds a new [Pipeline] to the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to add. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline, Model, Splitter}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let mut collection = Collection::new("my_collection", None); - /// collection.add_pipeline(&mut pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn add_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Create collection if it does not exist + // 2. Create the pipeline if it does not exist and add it to the collection.pipelines table with ACTIVE = TRUE + // 3. Sync the pipeline - this will delete all previous chunks, embeddings, and tsvectors self.verify_in_database(false).await?; - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - let mp = MultiProgress::new(); - mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; - pipeline.execute(&None, mp).await?; - eprintln!("Done Syncing {}\n", pipeline.name); + let project_info = &self + .database_data + .as_ref() + .context("Database data must be set to add a pipeline to a collection")? + .project_info; + + // Let's check if we already have it enabled + let pool = get_or_initialize_pool(&self.database_url).await?; + let pipelines_table_name = format!("{}.pipelines", project_info.name); + let exists: bool = sqlx::query_scalar(&query_builder!( "SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)", pipelines_table_name )) .bind(&pipeline.name) .fetch_one(&pool) .await?; + + if exists { + warn!("Pipeline {} already exists, not adding", pipeline.name); + } else { + // We want to intentionally throw an error if they have already added this pipeline + // as we don't want to casually resync + pipeline + .verify_in_database(project_info, true, &pool) + .await?; + + let mp = MultiProgress::new(); + mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; + pipeline + .resync(project_info, pool.acquire().await?.as_mut()) + .await?; + mp.println(format!("Done Syncing {}\n", pipeline.name))?; } Ok(()) }
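// A hedged sketch of the new schema-driven pipeline API, mirroring the tests
// later in this diff; the field name and model choices are illustrative
// assumptions, not the only supported options:
let mut pipeline = Pipeline::new(
    "my_pipeline",
    Some(
        serde_json::json!({
            "body": {
                "splitter": { "model": "recursive_character" },
                "semantic_search": { "model": "intfloat/e5-small" },
                "full_text_search": { "configuration": "english" }
            }
        })
        .into(),
    ),
)?;
// Re-adding a pipeline that is already active only logs a warning and skips
// the resync:
collection.add_pipeline(&mut pipeline).await?;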
/// Removes a [Pipeline] from the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let mut collection = Collection::new("my_collection", None); - /// collection.remove_pipeline(&mut pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn remove_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&self.database_url).await?; + pub async fn remove_pipeline(&mut self, pipeline: &Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Create collection if it does not exist + // 2. Begin a transaction + // 3. Drop the collection_pipeline schema + // 4. Delete the pipeline from the collection.pipelines table + // 5. Commit the transaction self.verify_in_database(false).await?; - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - pipeline.verify_in_database(false).await?; - - let database_data = pipeline - .database_data - .as_ref() - .context("Pipeline must be verified to remove it")?; - - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - - let parameters = pipeline - .parameters - .as_ref() - .context("Pipeline must be verified to remove it")?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + let pipeline_schema = format!("{}_{}", project_info.name, pipeline.name); let mut transaction = pool.begin().await?; - - // Need to delete from chunks table only if no other pipelines use the same splitter - sqlx::query(&query_builder!( "DELETE FROM %s WHERE splitter_id = $1 AND NOT EXISTS (SELECT 1 FROM %s WHERE splitter_id = $1 AND id != $2)", self.chunks_table_name, self.pipelines_table_name )) .bind(database_data.splitter_id) .bind(database_data.id) .execute(&mut *transaction) + transaction .execute(query_builder!("DROP SCHEMA IF EXISTS %s CASCADE", pipeline_schema).as_str()) .await?; - - // Drop the embeddings table - sqlx::query(&query_builder!( "DROP TABLE IF EXISTS %s", embeddings_table_name )) - .execute(&mut *transaction) - .await?; - - // Need to delete from the tsvectors table only if no other pipelines use the - // same tsvector configuration - sqlx::query(&query_builder!( "DELETE FROM %s WHERE configuration = $1 AND NOT EXISTS (SELECT 1 FROM %s WHERE parameters->'full_text_search'->>'configuration' = $1 AND id != $2)", self.documents_tsvectors_table_name, self.pipelines_table_name)) .bind(parameters["full_text_search"]["configuration"].as_str()) .bind(database_data.id) .execute(&mut *transaction) - .await?; sqlx::query(&query_builder!( - "DELETE FROM %s WHERE id = $1", + "DELETE FROM %s WHERE name = $1", self.pipelines_table_name )) - .bind(database_data.id) + .bind(&pipeline.name) .execute(&mut *transaction) .await?; - transaction.commit().await?; Ok(()) }
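// Under the same assumptions, removal now drops the pipeline's dedicated
// `<collection>_<pipeline>` schema, taking its chunks, embeddings, and
// tsvectors with it:
collection.remove_pipeline(&pipeline).await?;
assert!(collection.get_pipeline("my_pipeline").await.is_err());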
/// Enables a [Pipeline] on the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let collection = Collection::new("my_collection", None); - /// collection.enable_pipeline(&pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn enable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { + pub async fn enable_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1. Set ACTIVE = TRUE for the pipeline in collection.pipelines + // 2. Resync the pipeline + // TODO: Review this pattern + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; sqlx::query(&query_builder!( "UPDATE %s SET active = TRUE WHERE name = $1", self.pipelines_table_name )) .bind(&pipeline.name) - .execute(&get_or_initialize_pool(&self.database_url).await?) + .execute(&pool) .await?; - Ok(()) + pipeline + .resync(project_info, pool.acquire().await?.as_mut()) + .await } /// Disables a [Pipeline] on the [Collection] - /// - /// # Arguments - /// - /// * `pipeline` - The [Pipeline] to remove. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let collection = Collection::new("my_collection", None); - /// collection.disable_pipeline(&pipeline).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { + // The flow for this function: + // 1.
Set ACTIVE = FALSE for the pipeline in collection.pipelines sqlx::query(&query_builder!( "UPDATE %s SET active = FALSE WHERE name = $1", self.pipelines_table_name @@ -429,110 +376,13 @@ impl Collection { query_builder!(queries::CREATE_DOCUMENTS_TABLE, self.documents_table_name).as_str(), ) .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "created_at_index", - self.documents_table_name, - "created_at" - ) - .as_str(), - ) - .await?; conn.execute( query_builder!( queries::CREATE_INDEX_USING_GIN, "", - "metadata_index", + "documents_document_index", self.documents_table_name, - "metadata jsonb_path_ops" - ) - .as_str(), - ) - .await?; - Ok(()) - } - - #[instrument(skip(self, conn))] - async fn create_chunks_table(&mut self, conn: &mut PgConnection) -> anyhow::Result<()> { - conn.execute( - query_builder!( - queries::CREATE_CHUNKS_TABLE, - self.chunks_table_name, - self.documents_table_name - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "created_at_index", - self.chunks_table_name, - "created_at" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "document_id_index", - self.chunks_table_name, - "document_id" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "splitter_id_index", - self.chunks_table_name, - "splitter_id" - ) - .as_str(), - ) - .await?; - Ok(()) - } - - #[instrument(skip(self, conn))] - async fn create_documents_tsvectors_table( - &mut self, - conn: &mut PgConnection, - ) -> anyhow::Result<()> { - conn.execute( - query_builder!( - queries::CREATE_DOCUMENTS_TSVECTORS_TABLE, - self.documents_tsvectors_table_name, - self.documents_table_name - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX, - "", - "configuration_index", - self.documents_tsvectors_table_name, - "configuration" - ) - .as_str(), - ) - .await?; - conn.execute( - query_builder!( - queries::CREATE_INDEX_USING_GIN, - "", - "tsvector_index", - self.documents_tsvectors_table_name, - "ts" + "document jsonb_path_ops" ) .as_str(), ) @@ -541,164 +391,178 @@ impl Collection { } /// Upserts documents into the database - /// - /// # Arguments - /// - /// * `documents` - A vector of documents to upsert - /// * `strict` - Whether to throw an error if keys: `id` or `text` are missing from any documents - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = vec![ - /// serde_json::json!({"id": 1, "text": "hello world"}).into(), - /// serde_json::json!({"id": 2, "text": "hello world"}).into(), - /// ]; - /// collection.upsert_documents(documents, None).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self, documents))] pub async fn upsert_documents( &mut self, documents: Vec, args: Option, ) -> anyhow::Result<()> { - let pool = get_or_initialize_pool(&self.database_url).await?; + // The flow for this function + // 1. Create the collection if it does not exist + // 2. Get all pipelines where ACTIVE = TRUE + // -> Foreach pipeline get the parsed schema + // 4. 
Foreach n documents + // -> Begin a transaction returning the old document if it existed + // -> Insert the document + // -> Foreach pipeline check if we need to resync the document and if so sync the document + // -> Commit the transaction self.verify_in_database(false).await?; + let mut pipelines = self.get_pipelines().await?; + + let pool = get_or_initialize_pool(&self.database_url).await?; + + let mut parsed_schemas = vec![]; + let project_info = &self.database_data.as_ref().unwrap().project_info; + for pipeline in &mut pipelines { + let parsed_schema = pipeline + .get_parsed_schema(project_info, &pool) + .await + .expect("Error getting parsed schema for pipeline"); + parsed_schemas.push(parsed_schema); + } + let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect(); let args = args.unwrap_or_default(); + let args = args.as_object().context("args must be a JSON object")?; let progress_bar = utils::default_progress_bar(documents.len() as u64); progress_bar.println("Upserting Documents..."); - let documents: anyhow::Result> = documents - .into_iter() - .map(|mut document| { - let document = document - .as_object_mut() - .context("Documents must be a vector of objects")?; - - // We don't want the text included in the document metadata, but everything else - // should be in there - let text = document.remove("text").map(|t| { - t.as_str() - .expect("`text` must be a string in document") - .to_string() - }); - let metadata = serde_json::to_value(&document)?.into(); + let query = if args + .get("merge") + .map(|v| v.as_bool().unwrap_or(false)) + .unwrap_or(false) + { + query_builder!( + queries::UPSERT_DOCUMENT_AND_MERGE_METADATA, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ) + } else { + query_builder!( + queries::UPSERT_DOCUMENT, + self.documents_table_name, + self.documents_table_name, + self.documents_table_name + ) + }; + + let batch_size = args + .get("batch_size") + .map(TryToNumeric::try_to_u64) + .unwrap_or(Ok(100))?; + + for batch in documents.chunks(batch_size as usize) { + let mut transaction = pool.begin().await?; + + let mut query_values = String::new(); + let mut binding_parameter_counter = 1; + for _ in 0..batch.len() { + query_values = format!( + "{query_values}, (${}, ${}, ${})", + binding_parameter_counter, + binding_parameter_counter + 1, + binding_parameter_counter + 2 + ); + binding_parameter_counter += 3; + } + + let query = query.replace( + "{values_parameters}", + &query_values.chars().skip(1).collect::(), + ); + let query = query.replace( + "{binding_parameter}", + &format!("${binding_parameter_counter}"), + ); + let mut query = sqlx::query_as(&query); + + let mut source_uuids = vec![]; + for document in batch { let id = document .get("id") .context("`id` must be a key in document")? .to_string(); let md5_digest = md5::compute(id.as_bytes()); let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?; + source_uuids.push(source_uuid); - Ok((source_uuid, text, metadata)) - }) - .collect(); - - // We could continue chaining the above iterators but types become super annoying to - // deal with, especially because we are dealing with async functions. This is much easier to read - // Also, we may want to use a variant of chunks that is owned, I'm not 100% sure of what - // cloning happens when passing values into sqlx bind. 
itertools variants will not work as - // it is not thread safe and pyo3 will get upset - let mut document_ids = Vec::new(); - for chunk in documents?.chunks(10) { - // Need to make it a vec to partition it and must include explicit typing here - let mut chunk: Vec<&(uuid::Uuid, Option, Json)> = chunk.iter().collect(); - - // Split the chunk into two groups, one with text, and one with just metadata - let split_index = itertools::partition(&mut chunk, |(_, text, _)| text.is_some()); - let (text_chunk, metadata_chunk) = chunk.split_at(split_index); - - // Start the transaction - let mut transaction = pool.begin().await?; + let start = SystemTime::now(); + let timestamp = start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); - if !metadata_chunk.is_empty() { - // Update the metadata - // Merge the metadata if the user has specified to do so otherwise replace it - if args["metadata"]["merge"].as_bool().unwrap_or(false) { - sqlx::query(query_builder!( - "UPDATE %s d SET metadata = d.metadata || v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", - self.documents_table_name - ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) - .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) - .execute(&mut *transaction).await?; - } else { - sqlx::query(query_builder!( - "UPDATE %s d SET metadata = v.metadata FROM (SELECT UNNEST($1) source_uuid, UNNEST($2) metadata) v WHERE d.source_uuid = v.source_uuid", - self.documents_table_name - ).as_str()).bind(metadata_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()) - .bind(metadata_chunk.iter().map(|(_, _, metadata)| metadata.0.clone()).collect::>()) - .execute(&mut *transaction).await?; - } + let versions: HashMap = document + .as_object() + .context("document must be an object")? + .iter() + .try_fold(HashMap::new(), |mut acc, (key, value)| { + let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes()); + let md5_digest = format!("{md5_digest:x}"); + acc.insert( + key.to_owned(), + serde_json::json!({ + "last_updated": timestamp, + "md5": md5_digest + }), + ); + anyhow::Ok(acc) + })?; + let versions = serde_json::to_value(versions)?; + + query = query.bind(source_uuid).bind(document).bind(versions); } - if !text_chunk.is_empty() { - // First delete any documents that already have the same UUID as documents in - // text_chunk, then insert the new ones. - // We are essentially upserting in two steps - sqlx::query(&query_builder!( - "DELETE FROM %s WHERE source_uuid IN (SELECT source_uuid FROM %s WHERE source_uuid = ANY($1::uuid[]))", - self.documents_table_name, - self.documents_table_name - )). - bind(&text_chunk.iter().map(|(source_uuid, _, _)| *source_uuid).collect::>()). 
- execute(&mut *transaction).await?; - let query_string_values = (0..text_chunk.len()) - .map(|i| format!("(${}, ${}, ${})", i * 3 + 1, i * 3 + 2, i * 3 + 3)) - .collect::>() - .join(","); - let query_string = format!( - "INSERT INTO %s (source_uuid, text, metadata) VALUES {} ON CONFLICT (source_uuid) DO UPDATE SET text = $2, metadata = $3 RETURNING id", - query_string_values - ); - let query = query_builder!(query_string, self.documents_table_name); - let mut query = sqlx::query_scalar(&query); - for (source_uuid, text, metadata) in text_chunk.iter() { - query = query.bind(source_uuid).bind(text).bind(metadata); + let results: Vec<(i64, Option)> = query + .bind(source_uuids) + .fetch_all(&mut *transaction) + .await?; + + let dp: Vec<(i64, Json, Option)> = results + .into_iter() + .zip(batch) + .map(|((id, previous_document), document)| { + (id, document.to_owned(), previous_document) + }) + .collect(); + + for (pipeline, parsed_schema) in &mut pipelines { + let ids_to_run_on: Vec = dp + .iter() + .filter(|(_, document, previous_document)| match previous_document { + Some(previous_document) => parsed_schema + .iter() + .any(|(key, _)| document[key] != previous_document[key]), + None => true, + }) + .map(|(document_id, _, _)| *document_id) + .collect(); + if !ids_to_run_on.is_empty() { + pipeline + .sync_documents(ids_to_run_on, project_info, &mut transaction) + .await + .expect("Failed to execute pipeline"); } - let ids: Vec = query.fetch_all(&mut *transaction).await?; - document_ids.extend(ids); - progress_bar.inc(chunk.len() as u64); } transaction.commit().await?; + progress_bar.inc(batch_size); } + progress_bar.println("Done Upserting Documents\n"); progress_bar.finish(); - eprintln!("Done Upserting Documents\n"); - - self.sync_pipelines(Some(document_ids)).await?; Ok(()) } /// Gets the documents on a [Collection] - /// - /// # Arguments - /// - /// * `args` - The filters and options to apply to the query - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.get_documents(None).await?; - /// Ok(()) - /// } #[instrument(skip(self))] pub async fn get_documents(&self, args: Option) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let mut args = args.unwrap_or_default().0; + let mut args = args.unwrap_or_default(); let args = args.as_object_mut().context("args must be an object")?; // Get limit or set it to 1000 @@ -718,7 +582,7 @@ impl Collection { if let Some(order_by) = args.remove("order_by") { let order_by_builder = - order_by_builder::OrderByBuilder::new(order_by, "documents", "metadata").build()?; + order_by_builder::OrderByBuilder::new(order_by, "documents", "document").build()?; for (order_by, order) in order_by_builder { query.order_by_expr_with_nulls(order_by, order, NullOrdering::Last); } @@ -738,53 +602,9 @@ impl Collection { query.offset(offset); } - if let Some(mut filter) = args.remove("filter") { - let filter = filter - .as_object_mut() - .context("filter must be a Json object")?; - - if let Some(f) = filter.remove("metadata") { - query.cond_where( - filter_builder::FilterBuilder::new(f, "documents", "metadata").build(), - ); - } - if let Some(f) = filter.remove("full_text_search") { - let f = f - .as_object() - .context("Full text filter must be a Json object")?; - let configuration = f - .get("configuration") - .context("In full_text_search `configuration` is required")? 
- .as_str() - .context("In full_text_search `configuration` must be a string")?; - let filter_text = f - .get("text") - .context("In full_text_search `text` is required")? - .as_str() - .context("In full_text_search `text` must be a string")?; - query - .join_as( - JoinType::InnerJoin, - self.documents_tsvectors_table_name.to_table_tuple(), - Alias::new("documents_tsvectors"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - ) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )); - } + if let Some(filter) = args.remove("filter") { + let filter = FilterBuilder::new(filter, "documents", "document").build()?; + query.cond_where(filter); } let (sql, values) = query.build_sqlx(PostgresQueryBuilder); @@ -797,83 +617,15 @@ impl Collection { } /// Deletes documents in a [Collection] - /// - /// # Arguments - /// - /// * `filter` - The filters to apply - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let documents = collection.delete_documents(serde_json::json!({ - /// "metadata": { - /// "id": { - /// "eq": 1 - /// } - /// } - /// }).into()).await?; - /// Ok(()) - /// } #[instrument(skip(self))] - pub async fn delete_documents(&self, mut filter: Json) -> anyhow::Result<()> { + pub async fn delete_documents(&self, filter: Json) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; let mut query = Query::delete(); query.from_table(self.documents_table_name.to_table_tuple()); - let filter = filter - .as_object_mut() - .context("filter must be a Json object")?; - - if let Some(f) = filter.remove("metadata") { - query - .cond_where(filter_builder::FilterBuilder::new(f, "documents", "metadata").build()); - } - - if let Some(mut f) = filter.remove("full_text_search") { - let f = f - .as_object_mut() - .context("Full text filter must be a Json object")?; - let configuration = f - .get("configuration") - .context("In full_text_search `configuration` is required")? - .as_str() - .context("In full_text_search `configuration` must be a string")?; - let filter_text = f - .get("text") - .context("In full_text_search `text` is required")? 
- .as_str() - .context("In full_text_search `text` must be a string")?; - let mut inner_select_query = Query::select(); - inner_select_query - .from_as( - self.documents_tsvectors_table_name.to_table_tuple(), - SIden::Str("documents_tsvectors"), - ) - .column(SIden::Str("document_id")) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ); - query.and_where( - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .in_subquery(inner_select_query), - ); - } + let filter = FilterBuilder::new(filter.0, "documents", "document").build()?; + query.cond_where(filter); let (sql, values) = query.build_sqlx(PostgresQueryBuilder); sqlx::query_with(&sql, values).fetch_all(&pool).await?; @@ -881,198 +633,174 @@ impl Collection { } #[instrument(skip(self))] - pub(crate) async fn sync_pipelines( - &mut self, - document_ids: Option>, - ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pipelines = self.get_pipelines().await?; - if !pipelines.is_empty() { - let mp = MultiProgress::new(); - mp.println("Syncing Pipelines...")?; - use futures::stream::StreamExt; - futures::stream::iter(pipelines) - // Need this map to get around moving the document_ids and mp - .map(|pipeline| (pipeline, document_ids.clone(), mp.clone())) - .for_each_concurrent(10, |(mut pipeline, document_ids, mp)| async move { - pipeline - .execute(&document_ids, mp) - .await - .expect("Failed to execute pipeline"); - }) - .await; - eprintln!("Done Syncing Pipelines\n"); + pub async fn search(&mut self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; + let results: Result<(Json,), _> = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) + .await; + + match results { + Ok(r) => Ok(r.0), + Err(e) => match e.as_database_error() { + Some(d) => { + if d.code() == Some(Cow::from("XX000")) { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; + let (built_query, values) = + build_search_query(self, query, pipeline).await?; + let results: (Json,) = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) + .await?; + Ok(results.0) + } else { + Err(anyhow::anyhow!(e)) + } + } + None => Err(anyhow::anyhow!(e)), + }, } + } + + #[instrument(skip(self))] + pub async fn search_local(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { + let pool = get_or_initialize_pool(&self.database_url).await?; + let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; + let results: (Json,) = sqlx::query_as_with(&built_query, values) + .fetch_one(&pool) + .await?; + Ok(results.0) + } + + #[instrument(skip(self))] + pub async fn add_search_event( + &self, + search_id: i64, + search_result: i64, + event: Json, + pipeline: &Pipeline, + ) -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&self.database_url).await?; + let search_events_table = format!("{}_{}.search_events", self.name, pipeline.name); + let search_results_table = format!("{}_{}.search_results", self.name, pipeline.name); + + let query = query_builder!( + queries::INSERT_SEARCH_EVENT, + search_events_table, + 
search_results_table + ); + debug_sqlx_query!( + INSERT_SEARCH_EVENT, + query, + search_id, + search_result, + event.0 + ); + sqlx::query(&query) + .bind(search_id) + .bind(search_result) + .bind(event.0) + .execute(&pool) + .await?; Ok(()) } /// Performs vector search on the [Collection] - /// - /// # Arguments - /// - /// * `query` - The query to search for - /// * `pipeline` - The [Pipeline] used for the search - /// * `query_paramaters` - The query parameters passed to the model for search - /// * `top_k` - How many results to limit on. - /// - /// # Example - /// - /// ``` - /// use pgml::{Collection, Pipeline}; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = Pipeline::new("my_pipeline", None, None, None); - /// let results = collection.vector_search("Query", &mut pipeline, None, None).await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] #[allow(clippy::type_complexity)] pub async fn vector_search( &mut self, - query: &str, + query: Json, pipeline: &mut Pipeline, - query_parameters: Option, - top_k: Option, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; - let query_parameters = query_parameters.unwrap_or_default(); - let top_k = top_k.unwrap_or(5); - - // With this system, we only do the wrong type of vector search once - let runtime = if pipeline.model.is_some() { - pipeline.model.as_ref().unwrap().runtime - } else { - ModelRuntime::Python - }; - match runtime { - ModelRuntime::Python => { - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - - let result = sqlx::query_as(&query_builder!( - queries::EMBED_AND_VECTOR_SEARCH, - self.pipelines_table_name, - embeddings_table_name, - self.chunks_table_name, - self.documents_table_name - )) - .bind(&pipeline.name) - .bind(query) - .bind(&query_parameters) - .bind(top_k) + let (built_query, values) = + build_vector_search_query(query.clone(), self, pipeline).await?; + let results: Result, _> = + sqlx::query_as_with(&built_query, values) .fetch_all(&pool) .await; - - match result { - Ok(r) => Ok(r), - Err(e) => match e.as_database_error() { - Some(d) => { - if d.code() == Some(Cow::from("XX000")) { - self.vector_search_with_remote_embeddings( - query, - pipeline, - query_parameters, - top_k, - &pool, - ) - .await - } else { - Err(anyhow::anyhow!(e)) - } - } - None => Err(anyhow::anyhow!(e)), - }, + match results { + Ok(r) => Ok(r + .into_iter() + .map(|v| { + serde_json::json!({ + "document": v.0, + "chunk": v.1, + "score": v.2 + }) + .into() + }) + .collect()), + Err(e) => match e.as_database_error() { + Some(d) => { + if d.code() == Some(Cow::from("XX000")) { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; + let (built_query, values) = + build_vector_search_query(query, self, pipeline).await?; + let results: Vec<(Json, String, f64)> = + sqlx::query_as_with(&built_query, values) + .fetch_all(&pool) + .await?; + Ok(results + .into_iter() + .map(|v| { + serde_json::json!({ + "document": v.0, + "chunk": v.1, + "score": v.2 + }) + .into() + }) + .collect()) + } else { + Err(anyhow::anyhow!(e)) + } } - } - _ => { - self.vector_search_with_remote_embeddings( - query, - pipeline, - query_parameters, - top_k, - &pool, - ) - .await - } + None => Err(anyhow::anyhow!(e)), + }, } - .map(|r| { - r.into_iter() - 
.map(|(score, id, metadata)| (1. - score, id, metadata)) - .collect() - }) - } - - #[instrument(skip(self, pool))] - #[allow(clippy::type_complexity)] - async fn vector_search_with_remote_embeddings( - &mut self, - query: &str, - pipeline: &mut Pipeline, - query_parameters: Json, - top_k: i64, - pool: &PgPool, - ) -> anyhow::Result> { - self.verify_in_database(false).await?; - - // Have to set the project info before we can get and set the model - pipeline.set_project_info( - self.database_data - .as_ref() - .context( - "Collection must be verified to perform vector search with remote embeddings", - )? - .project_info - .clone(), - ); - // Verify to get and set the model if we don't have it set on the pipeline yet - pipeline.verify_in_database(false).await?; - let model = pipeline - .model - .as_ref() - .context("Pipeline must be verified to perform vector search with remote embeddings")?; - - // We need to make sure we are not mutably and immutably borrowing the same things - let embedding = { - let remote_embeddings = - build_remote_embeddings(model.runtime, &model.name, &query_parameters)?; - let mut embeddings = remote_embeddings.embed(vec![query.to_string()]).await?; - std::mem::take(&mut embeddings[0]) - }; - - let embeddings_table_name = format!("{}.{}_embeddings", self.name, pipeline.name); - sqlx::query_as(&query_builder!( - queries::VECTOR_SEARCH, - embeddings_table_name, - self.chunks_table_name, - self.documents_table_name - )) - .bind(embedding) - .bind(top_k) - .fetch_all(pool) - .await - .map_err(|e| anyhow::anyhow!(e)) } #[instrument(skip(self))] pub async fn archive(&mut self) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; + let pipelines = self.get_pipelines().await?; let timestamp = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .expect("Error getting system time") .as_secs(); - let archive_table_name = format!("{}_archive_{}", &self.name, timestamp); + let collection_archive_name = format!("{}_archive_{}", &self.name, timestamp); let mut transaciton = pool.begin().await?; + // Change name in pgml.collections sqlx::query("UPDATE pgml.collections SET name = $1, active = FALSE where name = $2") - .bind(&archive_table_name) + .bind(&collection_archive_name) .bind(&self.name) .execute(&mut *transaciton) .await?; + // Change collection_pipeline schema + for pipeline in pipelines { + sqlx::query(&query_builder!( + "ALTER SCHEMA %s RENAME TO %s", + format!("{}_{}", self.name, pipeline.name), + format!("{}_{}", collection_archive_name, pipeline.name) + )) + .execute(&mut *transaciton) + .await?; + } + // Change collection schema sqlx::query(&query_builder!( "ALTER SCHEMA %s RENAME TO %s", &self.name, - archive_table_name + collection_archive_name )) .execute(&mut *transaciton) .await?; @@ -1086,145 +814,35 @@ impl Collection { } /// Gets all pipelines for the [Collection] - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let pipelines = collection.get_pipelines().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; let pool = get_or_initialize_pool(&self.database_url).await?; - - let pipelines_with_models_and_splitters: Vec = - sqlx::query_as(&query_builder!( - r#"SELECT - p.id as pipeline_id, - p.name as pipeline_name, - p.created_at as pipeline_created_at, - p.active as 
pipeline_active, - p.parameters as pipeline_parameters, - m.id as model_id, - m.created_at as model_created_at, - m.runtime::TEXT as model_runtime, - m.hyperparams as model_hyperparams, - s.id as splitter_id, - s.created_at as splitter_created_at, - s.name as splitter_name, - s.parameters as splitter_parameters - FROM - %s p - INNER JOIN pgml.models m ON p.model_id = m.id - INNER JOIN pgml.splitters s ON p.splitter_id = s.id - WHERE - p.active = TRUE - "#, - self.pipelines_table_name - )) - .fetch_all(&pool) - .await?; - - let pipelines: Vec = pipelines_with_models_and_splitters - .into_iter() - .map(|p| { - let mut pipeline: Pipeline = p.into(); - pipeline.set_project_info( - self.database_data - .as_ref() - .expect("Collection must be verified to get all pipelines") - .project_info - .clone(), - ); - pipeline - }) - .collect(); - Ok(pipelines) + let pipelines: Vec = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE active = TRUE", + self.pipelines_table_name + )) + .fetch_all(&pool) + .await?; + pipelines.into_iter().map(|p| p.try_into()).collect() } /// Gets a [Pipeline] by name - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let pipeline = collection.get_pipeline("my_pipeline").await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; let pool = get_or_initialize_pool(&self.database_url).await?; - - let pipeline_with_model_and_splitter: models::PipelineWithModelAndSplitter = - sqlx::query_as(&query_builder!( - r#"SELECT - p.id as pipeline_id, - p.name as pipeline_name, - p.created_at as pipeline_created_at, - p.active as pipeline_active, - p.parameters as pipeline_parameters, - m.id as model_id, - m.created_at as model_created_at, - m.runtime::TEXT as model_runtime, - m.hyperparams as model_hyperparams, - s.id as splitter_id, - s.created_at as splitter_created_at, - s.name as splitter_name, - s.parameters as splitter_parameters - FROM - %s p - INNER JOIN pgml.models m ON p.model_id = m.id - INNER JOIN pgml.splitters s ON p.splitter_id = s.id - WHERE - p.active = TRUE - AND p.name = $1 - "#, - self.pipelines_table_name - )) - .bind(name) - .fetch_one(&pool) - .await?; - - let mut pipeline: Pipeline = pipeline_with_model_and_splitter.into(); - pipeline.set_project_info(self.database_data.as_ref().unwrap().project_info.clone()); - Ok(pipeline) - } - - #[instrument(skip(self))] - pub(crate) async fn get_project_info(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - Ok(self - .database_data - .as_ref() - .context("Collection must be verified to get project info")? 
- .project_info - .clone()) + let pipeline: models::Pipeline = sqlx::query_as(&query_builder!( + "SELECT * FROM %s WHERE name = $1 AND active = TRUE LIMIT 1", + self.pipelines_table_name + )) + .bind(name) + .fetch_one(&pool) + .await?; + pipeline.try_into() } /// Check if the [Collection] exists in the database - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let collection = Collection::new("my_collection", None); - /// let exists = collection.exists().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] pub async fn exists(&self) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -1312,6 +930,136 @@ impl Collection { Ok(()) } + #[instrument(skip(self))] + pub async fn get_pipeline_status(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + pipeline.get_status(project_info, &pool).await + } + + #[instrument(skip(self))] + pub async fn generate_er_diagram(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { + self.verify_in_database(false).await?; + let project_info = &self.database_data.as_ref().unwrap().project_info; + let pool = get_or_initialize_pool(&self.database_url).await?; + pipeline + .verify_in_database(project_info, false, &pool) + .await?; + + let parsed_schema = pipeline + .parsed_schema + .as_ref() + .context("Pipeline must have schema to generate er diagram")?; + + let mut uml_entites = format!( + r#" +@startuml +' hide the spot +' hide circle + +' avoid problems with angled crows feet +skinparam linetype ortho + +entity "pgml.collections" as pgmlc {{ + id : bigint + -- + created_at : timestamp without time zone + name : text + active : boolean + project_id : bigint + sdk_version : text +}} + +entity "{}.documents" as documents {{ + id : bigint + -- + created_at : timestamp without time zone + source_uuid : uuid + document : jsonb +}} + +entity "{}.pipelines" as pipelines {{ + id : bigint + -- + created_at : timestamp without time zone + name : text + active : boolean + schema : jsonb +}} + "#, + self.name, self.name + ); + + let schema = format!("{}_{}", self.name, pipeline.name); + + let mut uml_relations = r#" +pgmlc ||..|| pipelines + "# + .to_string(); + + for (key, field_action) in parsed_schema.iter() { + let nice_name_key = key.replace(' ', "_"); + + let relations = format!( + r#" +documents ||..|{{ {nice_name_key}_chunks +{nice_name_key}_chunks ||.|| {nice_name_key}_embeddings + "# + ); + uml_relations.push_str(&relations); + + if let Some(_embed_action) = &field_action.semantic_search { + let entites = format!( + r#" +entity "{schema}.{key}_chunks" as {nice_name_key}_chunks {{ + id : bigint + -- + created_at : timestamp without time zone + document_id : bigint + chunk_index : bigint + chunk : text +}} + +entity "{schema}.{key}_embeddings" as {nice_name_key}_embeddings {{ + id : bigint + -- + created_at : timestamp without time zone + chunk_id : bigint + embedding : vector +}} + "# + ); + uml_entites.push_str(&entites); + } + + if let Some(_full_text_search_action) = &field_action.full_text_search { + let entites = format!( + r#" +entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ + id : bigint + -- + created_at : timestamp without time zone + chunk_id : bigint + tsvectors : tsvector +}} + "# + ); + uml_entites.push_str(&entites); + + let 
relations = format!( r#" +{nice_name_key}_chunks ||..|| {nice_name_key}_tsvectors "# ); uml_relations.push_str(&relations); + } + } + + uml_entites.push_str(&uml_relations); + Ok(uml_entites) + } + pub async fn upsert_file(&mut self, path: &str) -> anyhow::Result<()> { self.verify_in_database(false).await?; let path = Path::new(path); @@ -1323,11 +1071,10 @@ impl Collection { self.upsert_documents(vec![document.into()], None).await } - fn generate_table_names(name: &str) -> (String, String, String, String, String) { + fn generate_table_names(name: &str) -> (String, String, String, String) { [ ".pipelines", ".documents", - ".transforms", ".chunks", ".documents_tsvectors", ] diff --git a/pgml-sdks/pgml/src/filter_builder.rs b/pgml-sdks/pgml/src/filter_builder.rs index 32b9f4126..947f04bfc 100644 --- a/pgml-sdks/pgml/src/filter_builder.rs +++ b/pgml-sdks/pgml/src/filter_builder.rs @@ -1,49 +1,8 @@ -use sea_query::{ - extension::postgres::PgExpr, value::ArrayType, Condition, Expr, IntoCondition, SimpleExpr, -}; - -fn get_sea_query_array_type(value: &serde_json::Value) -> ArrayType { - if value.is_null() { - panic!("Invalid metadata filter configuration") - } else if value.is_string() { - ArrayType::String - } else if value.is_i64() || value.is_u64() { - ArrayType::BigInt - } else if value.is_f64() { - ArrayType::Double - } else if value.is_boolean() { - ArrayType::Bool - } else if value.is_array() { - let value = value - .as_array() - .expect("Invalid metadata filter configuration"); - get_sea_query_array_type(&value[0]) - } else { - panic!("Invalid metadata filter configuration") - } -} +use anyhow::Context; +use sea_query::{extension::postgres::PgExpr, Condition, Expr, IntoCondition, SimpleExpr}; fn serde_value_to_sea_query_value(value: &serde_json::Value) -> sea_query::Value { - if value.is_string() { - sea_query::Value::String(Some(Box::new(value.as_str().unwrap().to_string()))) - } else if value.is_i64() { - sea_query::Value::BigInt(Some(value.as_i64().unwrap())) - } else if value.is_f64() { - sea_query::Value::Double(Some(value.as_f64().unwrap())) - } else if value.is_boolean() { - sea_query::Value::Bool(Some(value.as_bool().unwrap())) - } else if value.is_array() { - let value = value.as_array().unwrap(); - let ty = get_sea_query_array_type(&value[0]); - let value = Some(Box::new( value.iter().map(serde_value_to_sea_query_value).collect(), )); - sea_query::Value::Array(ty, value) - } else if value.is_object() { - sea_query::Value::Json(Some(Box::new(value.clone()))) - } else { - panic!("Invalid metadata filter configuration") - } + sea_query::Value::Json(Some(Box::new(value.clone()))) } fn reconstruct_json(path: Vec<String>, value: serde_json::Value) -> serde_json::Value { @@ -102,36 +61,13 @@ fn value_is_object_and_is_comparison_operator(value: &serde_json::Value) -> bool }) } -fn get_value_type(value: &serde_json::Value) -> String { - if value.is_object() { - let (_, value) = value - .as_object() - .expect("Invalid metadata filter configuration") - .iter() - .next() - .unwrap(); - get_value_type(value) - } else if value.is_array() { - let value = &value.as_array().unwrap()[0]; - get_value_type(value) - } else if value.is_string() { - "text".to_string() - } else if value.is_i64() || value.is_f64() { - "float8".to_string() - } else if value.is_boolean() { - "bool".to_string() - } else { - panic!("Invalid metadata filter configuration") - } -} - fn build_recursive<'a>( table_name: &'a str, column_name: &'a str, path: Vec<String>, filter: serde_json::Value, condition: Option<Condition>, -) -> Condition {
+) -> anyhow::Result { if filter.is_object() { let mut condition = condition.unwrap_or(Condition::all()); for (key, value) in filter.as_object().unwrap() { @@ -175,46 +111,43 @@ fn build_recursive<'a>( expression .contains(Expr::val(serde_value_to_sea_query_value(&json))) } else { - expression - .not() - .contains(Expr::val(serde_value_to_sea_query_value(&json))) + let expression = expression + .contains(Expr::val(serde_value_to_sea_query_value(&json))); + expression.not() } } else { - // If we are not checking whether two values are equal or not equal, we need to cast it to the correct type before doing the comparison - let ty = get_value_type(value); let expression = Expr::cust( format!( - "(\"{}\".\"{}\"#>>'{{{}}}')::{}", + "\"{}\".\"{}\"#>'{{{}}}'", table_name, column_name, - local_path.join(","), - ty + local_path.join(",") ) .as_str(), ); let expression = Expr::expr(expression); build_expression(expression, value.clone()) }; - expression.into_condition() + Ok(expression.into_condition()) } else { build_recursive(table_name, column_name, local_path, value.clone(), None) } } - }; + }?; condition = condition.add(sub_condition); } - condition + Ok(condition) } else if filter.is_array() { - let mut condition = condition.expect("Invalid metadata filter configuration"); + let mut condition = condition.context("Invalid metadata filter configuration")?; for value in filter.as_array().unwrap() { let local_path = path.clone(); let new_condition = - build_recursive(table_name, column_name, local_path, value.clone(), None); + build_recursive(table_name, column_name, local_path, value.clone(), None)?; condition = condition.add(new_condition); } - condition + Ok(condition) } else { - panic!("Invalid metadata filter configuration") + anyhow::bail!("Invalid metadata filter configuration") } } @@ -233,7 +166,7 @@ impl<'a> FilterBuilder<'a> { } } - pub fn build(self) -> Condition { + pub fn build(self) -> anyhow::Result { build_recursive( self.table_name, self.column_name, @@ -276,39 +209,41 @@ mod tests { } #[test] - fn eq_operator() { + fn eq_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "id": {"$eq": 1}, "id2": {"id3": {"$eq": "test"}}, "id4": {"id5": {"id6": {"$eq": true}}}, "id7": {"id8": {"id9": {"id10": {"$eq": [1, 2, 3]}}}} })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}' AND ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# ); + Ok(()) } #[test] - fn ne_operator() { + fn ne_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "id": {"$ne": 1}, "id2": {"id3": {"$ne": "test"}}, "id4": {"id5": {"id6": {"$ne": true}}}, "id7": {"id8": {"id9": {"id10": {"$ne": [1, 2, 3]}}}} })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE NOT "test_table"."metadata" @> E'{\"id\":1}' AND NOT "test_table"."metadata" @> E'{\"id2\":{\"id3\":\"test\"}}' AND NOT "test_table"."metadata" @> E'{\"id4\":{\"id5\":{\"id6\":true}}}' AND NOT "test_table"."metadata" @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}'"# + r#"SELECT "id" FROM "test_table" WHERE (NOT ("test_table"."metadata") @> E'{\"id\":1}') AND (NOT ("test_table"."metadata") @> E'{\"id2\":{\"id3\":\"test\"}}') AND (NOT ("test_table"."metadata") @> E'{\"id4\":{\"id5\":{\"id6\":true}}}') AND (NOT ("test_table"."metadata") @> E'{\"id7\":{\"id8\":{\"id9\":{\"id10\":[1,2,3]}}}}')"# ); + Ok(()) } #[test] - fn numeric_comparison_operators() { + fn numeric_comparison_operators() -> anyhow::Result<()> { let basic_comparison_operators = vec![">", ">=", "<", "<="]; let basic_comparison_operators_names = vec!["$gt", "$gte", "$lt", "$lte"]; for (operator, name) in basic_comparison_operators @@ -319,20 +254,22 @@ mod tests { "id": {name: 1}, "id2": {"id3": {name: 1}} })) - .build() + .build()? .to_valid_sql_query(); + println!("{sql}"); assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} 1 AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} 1"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} '1' AND ("test_table"."metadata"#>'{{id2,id3}}') {} '1'"##, operator, operator ) ); } + Ok(()) } #[test] - fn array_comparison_operators() { + fn array_comparison_operators() -> anyhow::Result<()> { let array_comparison_operators = vec!["IN", "NOT IN"]; let array_comparison_operators_names = vec!["$in", "$nin"]; for (operator, name) in array_comparison_operators @@ -343,68 +280,72 @@ mod tests { "id": {name: [1]}, "id2": {"id3": {name: [1]}} })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, format!( - r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>>'{{id}}')::float8 {} (1) AND ("test_table"."metadata"#>>'{{id2,id3}}')::float8 {} (1)"##, + r##"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata"#>'{{id}}') {} ('1') AND ("test_table"."metadata"#>'{{id2,id3}}') {} ('1')"##, operator, operator ) ); } + Ok(()) } #[test] - fn and_operator() { + fn and_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"# ); + Ok(()) } #[test] - fn or_operator() { + fn or_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$or": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}'"# ); + Ok(()) } #[test] - fn not_operator() { + fn not_operator() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$not": [ {"id": {"$eq": 1}}, {"id2": {"id3": {"$eq": 1}}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE NOT ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}')"# + r#"SELECT "id" FROM "test_table" WHERE NOT (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}')"# ); + Ok(()) } #[test] - fn random_difficult_tests() { + fn filter_builder_random_difficult_tests() -> anyhow::Result<()> { let sql = construct_filter_builder_with_json(json!({ "$and": [ {"$or": [ @@ -415,11 +356,11 @@ mod tests { {"id4": {"$eq": 1}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' OR "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') AND "test_table"."metadata" @> E'{\"id4\":1}'"# + r#"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' OR ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') AND ("test_table"."metadata") @> E'{\"id4\":1}'"# ); let sql = construct_filter_builder_with_json(json!({ "$or": [ @@ -431,11 +372,11 @@ mod tests { {"id4": {"$eq": 1}} ] })) - .build() + .build()? .to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata" @> E'{\"id\":1}' AND "test_table"."metadata" @> E'{\"id2\":{\"id3\":1}}') OR "test_table"."metadata" @> E'{\"id4\":1}'"# + r#"SELECT "id" FROM "test_table" WHERE (("test_table"."metadata") @> E'{\"id\":1}' AND ("test_table"."metadata") @> E'{\"id2\":{\"id3\":1}}') OR ("test_table"."metadata") @> E'{\"id4\":1}'"# ); let sql = construct_filter_builder_with_json(json!({ "metadata": {"$or": [ @@ -443,11 +384,12 @@ mod tests { {"uuid2": {"$eq": "2"}} ]} })) - .build() + .build()? 
.to_valid_sql_query(); assert_eq!( sql, - r#"SELECT "id" FROM "test_table" WHERE "test_table"."metadata" @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR "test_table"."metadata" @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# + r#"SELECT "id" FROM "test_table" WHERE ("test_table"."metadata") @> E'{\"metadata\":{\"uuid\":\"1\"}}' OR ("test_table"."metadata") @> E'{\"metadata\":{\"uuid2\":\"2\"}}'"# ); + Ok(()) } } diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index c49b5c493..f8de14587 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -4,10 +4,7 @@ use rust_bridge::javascript::{FromJsType, IntoJsResult}; use std::cell::RefCell; use std::sync::Arc; -use crate::{ - pipeline::PipelineSyncData, - types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, -}; +use crate::types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// // Rust to JS ////////////////////////////////////////////////////////////////// @@ -63,16 +60,6 @@ impl IntoJsResult for Json { } } -impl IntoJsResult for PipelineSyncData { - type Output = JsValue; - fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>( - self, - cx: &mut C, - ) -> JsResult<'b, Self::Output> { - Json::from(self).into_js_result(cx) - } -} - #[derive(Clone)] struct GeneralJsonAsyncIteratorJavaScript(Arc>); diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index 9d19b16bd..300091500 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -4,12 +4,7 @@ use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString}; use pyo3::{prelude::*, types::PyBool}; use std::sync::Arc; -use rust_bridge::python::CustomInto; - -use crate::{ - pipeline::PipelineSyncData, - types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, -}; +use crate::types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}; //////////////////////////////////////////////////////////////////////////////// // Rust to PY ////////////////////////////////////////////////////////////////// @@ -50,12 +45,6 @@ impl IntoPy for Json { } } -impl IntoPy for PipelineSyncData { - fn into_py(self, py: Python) -> PyObject { - Json::from(self).into_py(py) - } -} - #[pyclass] #[derive(Clone)] struct GeneralJsonAsyncIteratorPython { @@ -177,13 +166,6 @@ impl FromPyObject<'_> for Json { } } -impl FromPyObject<'_> for PipelineSyncData { - fn extract(ob: &PyAny) -> PyResult { - let json = Json::extract(ob)?; - Ok(json.into()) - } -} - impl FromPyObject<'_> for GeneralJsonAsyncIterator { fn extract(_ob: &PyAny) -> PyResult { panic!("We must implement this, but this is impossible to be reached") @@ -199,9 +181,3 @@ impl FromPyObject<'_> for GeneralJsonIterator { //////////////////////////////////////////////////////////////////////////////// // Rust to Rust ////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// - -impl CustomInto for PipelineSyncData { - fn custom_into(self) -> Json { - Json::from(self) - } -} diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index cef33c024..50665ed93 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -8,7 +8,7 @@ use parking_lot::RwLock; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::collections::HashMap; use std::env; -use tokio::runtime::Runtime; +use 
tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -28,10 +28,13 @@ mod queries; mod query_builder; mod query_runner; mod remote_embeddings; +mod search_query_builder; +mod single_field_pipeline; mod splitter; pub mod transformer_pipeline; pub mod types; mod utils; +mod vector_search_query_builder; // Re-export pub use builtins::Builtins; @@ -43,7 +46,9 @@ pub use splitter::Splitter; pub use transformer_pipeline::TransformerPipeline; // This is used when inserting collections to set the sdk_version used during creation -static SDK_VERSION: &str = "0.9.2"; +// This doesn't actually mean the version of the SDK it was created on; it means the +// version it is compatible with +static SDK_VERSION: &str = "1.0.0"; // Store the database(s) in a global variable so that we can access them from anywhere // This is not necessarily idiomatic Rust, but it is a good way to accomplish what we need @@ -54,12 +59,11 @@ static DATABASE_POOLS: RwLock<Option<HashMap<String, PgPool>>> = RwLock::new(None); async fn get_or_initialize_pool(database_url: &Option<String>) -> anyhow::Result<PgPool> { let mut pools = DATABASE_POOLS.write(); let pools = pools.get_or_insert_with(HashMap::new); - let environment_url = std::env::var("DATABASE_URL"); - let environment_url = environment_url.as_deref(); - let url = database_url - .as_deref() - .unwrap_or_else(|| environment_url.expect("Please set DATABASE_URL environment variable")); - if let Some(pool) = pools.get(url) { + let url = database_url.clone().unwrap_or_else(|| { + std::env::var("PGML_DATABASE_URL").unwrap_or_else(|_| + std::env::var("DATABASE_URL").expect("Please set PGML_DATABASE_URL environment variable or explicitly pass a database connection string to your collection")) + }); + if let Some(pool) = pools.get(&url) { Ok(pool.clone()) } else { let timeout = std::env::var("PGML_CHECKOUT_TIMEOUT") @@ -128,7 +132,11 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { if let Some(r) = &RUNTIME { r } else { - let runtime = Runtime::new().unwrap(); + // Need to use multi thread for JavaScript + let runtime = Builder::new_multi_thread() + .enable_all() + .build() + .expect("Error creating tokio runtime"); RUNTIME = Some(runtime); get_or_set_runtime() } @@ -157,6 +165,10 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_function(pyo3::wrap_pyfunction!(cli::cli, m)?)?; + m.add_function(pyo3::wrap_pyfunction!( + single_field_pipeline::SingleFieldPipeline, + m + )?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -204,6 +216,10 @@ fn migrate( fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { cx.export_function("init_logger", init_logger)?; cx.export_function("migrate", migrate)?; + cx.export_function( + "newSingleFieldPipeline", + single_field_pipeline::SingleFieldPipeline, + )?; cx.export_function("cli", cli::cli)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; @@ -224,16 +240,27 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { #[cfg(test)] mod tests { use super::*; - use crate::{model::Model, pipeline::Pipeline, splitter::Splitter, types::Json}; + use crate::types::Json; use serde_json::json; fn generate_dummy_documents(count: usize) -> Vec<Json> { let mut documents = Vec::new(); for i in 0..count { + let body_text = vec![format!( + "Here is some text that we will end up
splitting on! {i}" + )] + .into_iter() + .cycle() + .take(100) + .collect::>() + .join("\n"); let document = serde_json::json!( { "id": i, - "text": format!("This is a test document: {}", i), + "title": format!("Test document: {}", i), + "body": body_text, + "text": "here is some test text", + "notes": format!("Here are some notes or something for test document {}", i), "metadata": { "uuid": i * 10, "name": format!("Test Document {}", i) @@ -248,10 +275,10 @@ mod tests { // Collection & Pipelines ///// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_create_collection() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_ccc_0", None); + let mut collection = Collection::new("test_r_c_ccc_0", None)?; assert!(collection.database_data.is_none()); collection.verify_in_database(false).await?; assert!(collection.database_data.is_some()); @@ -259,525 +286,960 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let mut pipeline = Pipeline::new("test_p_carp_58", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_carp_1", None)?; + assert!(collection.database_data.is_none()); + collection.add_pipeline(&mut pipeline).await?; + assert!(collection.database_data.is_some()); + collection.remove_pipeline(&pipeline).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.is_empty()); + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_add_remove_pipelines() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut pipeline1 = Pipeline::new("test_r_p_carps_1", Some(json!({}).into()))?; + let mut pipeline2 = Pipeline::new("test_r_p_carps_2", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_carps_11", None)?; + collection.add_pipeline(&mut pipeline1).await?; + collection.add_pipeline(&mut pipeline2).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.len() == 2); + collection.remove_pipeline(&pipeline1).await?; + let pipelines = collection.get_pipelines().await?; + assert!(pipelines.len() == 1); + assert!(collection.get_pipeline("test_r_p_carps_1").await.is_err()); + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_add_pipeline_and_upsert_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_capaud_107"; + let pipeline_name = "test_r_p_capaud_6"; let mut pipeline = Pipeline::new( - "test_p_cap_57", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character", + "parameters": { + "chunk_size": 1000, + "chunk_overlap": 40 + } + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_carp_3", None); - assert!(collection.database_data.is_none()); + )?; + let mut collection = Collection::new(collection_name, None)?; collection.add_pipeline(&mut pipeline).await?; - 
assert!(collection.database_data.is_some()); - collection.remove_pipeline(&mut pipeline).await?; - let pipelines = collection.get_pipelines().await?; - assert!(pipelines.is_empty()); + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pool = get_or_initialize_pool(&None).await?; + let documents_table = format!("{}.documents", collection_name); + let queried_documents: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table)) + .fetch_all(&pool) + .await?; + assert!(queried_documents.len() == 2); + for (d, qd) in std::iter::zip(documents, queried_documents) { + assert_eq!(d, qd.document); + } + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); collection.archive().await?; Ok(()) } - // #[sqlx::test] - // async fn can_add_remove_pipelines() -> anyhow::Result<()> { - // internal_init_logger(None, None).ok(); - // let model = Model::default(); - // let splitter = Splitter::default(); - // let mut pipeline1 = Pipeline::new( - // "test_r_p_carps_0", - // Some(model.clone()), - // Some(splitter.clone()), - // None, - // ); - // let mut pipeline2 = Pipeline::new("test_r_p_carps_1", Some(model), Some(splitter), None); - // let mut collection = Collection::new("test_r_c_carps_1", None); - // collection.add_pipeline(&mut pipeline1).await?; - // collection.add_pipeline(&mut pipeline2).await?; - // let pipelines = collection.get_pipelines().await?; - // assert!(pipelines.len() == 2); - // collection.remove_pipeline(&mut pipeline1).await?; - // let pipelines = collection.get_pipelines().await?; - // assert!(pipelines.len() == 1); - // assert!(collection.get_pipeline("test_r_p_carps_0").await.is_err()); - // collection.archive().await?; - // Ok(()) - // } - - #[sqlx::test] - async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + #[tokio::test] + async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let collection_name = "test_r_c_cudaap_51"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cudaap_9"; let mut pipeline = Pipeline::new( - "test_r_p_cschpfp_0", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "hnsw": { - "m": 100, - "ef_construction": 200 + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let collection_name = "test_r_c_cschpfp_1"; - let mut collection = 
Collection::new(collection_name, None); + )?; collection.add_pipeline(&mut pipeline).await?; - let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; - let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; let pool = get_or_initialize_pool(&None).await?; - let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( - "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", - embeddings_table_name, - collection_name - )).fetch_all(&pool).await?; - let names = results.iter().map(|(name, _)| name).collect::>(); - let definitions = results - .iter() - .map(|(_, definition)| definition) - .collect::>(); - assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); - assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); + let documents_table = format!("{}.documents", collection_name); + let queried_documents: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table)) + .fetch_all(&pool) + .await?; + assert!(queried_documents.len() == 2); + for (d, qd) in std::iter::zip(documents, queried_documents) { + assert_eq!(d, qd.document); + } + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 4); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 4); + collection.archive().await?; Ok(()) } - #[sqlx::test] + #[tokio::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new("test_p_dep_0", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_dep_1", None); + let mut pipeline = Pipeline::new("test_p_dep_1", Some(json!({}).into()))?; + let mut collection = Collection::new("test_r_c_dep_1", None)?; collection.add_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); collection.disable_pipeline(&pipeline).await?; let queried_pipelines = &collection.get_pipelines().await?; assert!(queried_pipelines.is_empty()); - collection.enable_pipeline(&pipeline).await?; + collection.enable_pipeline(&mut pipeline).await?; let queried_pipeline = &collection.get_pipelines().await?[0]; assert_eq!(pipeline.name, queried_pipeline.name); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn sync_multiple_pipelines() -> anyhow::Result<()> { + #[tokio::test] + async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline1 = Pipeline::new( - "test_r_p_smp_0", - Some(model.clone()), - Some(splitter.clone()), + let collection_name = 
"test_r_c_cudaep_43"; + let mut collection = Collection::new(collection_name, None)?; + let pipeline_name = "test_r_p_cudaep_9"; + let mut pipeline = Pipeline::new( + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } } }) .into(), ), - ); - let mut pipeline2 = Pipeline::new( - "test_r_p_smp_1", - Some(model), - Some(splitter), + )?; + collection.add_pipeline(&mut pipeline).await?; + collection.disable_pipeline(&pipeline).await?; + let documents = generate_dummy_documents(2); + collection.upsert_documents(documents, None).await?; + let pool = get_or_initialize_pool(&None).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.is_empty()); + collection.enable_pipeline(&mut pipeline).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 2); + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn random_pipelines_documents_test() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test_r_c_rpdt_3"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(6); + collection + .upsert_documents(documents[..2].to_owned(), None) + .await?; + let pipeline_name1 = "test_r_p_rpdt1_0"; + let mut pipeline = Pipeline::new( + pipeline_name1, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_smp_3", None); - collection.add_pipeline(&mut pipeline1).await?; - collection.add_pipeline(&mut pipeline2).await?; + )?; + collection.add_pipeline(&mut pipeline).await?; + collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(documents[2..4].to_owned(), None) .await?; - let status_1 = pipeline1.get_status().await?; - let status_2 = pipeline2.get_status().await?; - assert!( - status_1.chunks_status.synced == status_1.chunks_status.total - && status_1.chunks_status.not_synced == 0 - ); - assert!( - status_2.chunks_status.synced == status_2.chunks_status.total - && status_2.chunks_status.not_synced == 0 - ); - collection.archive().await?; - Ok(()) - } - /////////////////////////////// - // Various Searches /////////// - /////////////////////////////// + let pool = get_or_initialize_pool(&None).await?; + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name1); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 4); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name1); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 8); + let 
tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name1); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 8); - #[sqlx::test] - async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let pipeline_name2 = "test_r_p_rpdt2_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswle_1", - Some(model), - Some(splitter), + pipeline_name2, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small", + }, + "full_text_search": { + "configuration": "english" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswle_28", None); + )?; collection.add_pipeline(&mut pipeline).await?; - // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswle_1", None, None, None); + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name2); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 4); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name2); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 8); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name2); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 8); + collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(documents[4..6].to_owned(), None) .await?; - let results = collection - .vector_search("Here is some query", &mut pipeline, None, None) - .await?; - assert!(results.len() == 3); + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name2); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 6); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name2); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name2); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); + + let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name1); + let title_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(title_chunks.len() == 6); + let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name1); + let body_chunks: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table)) + .fetch_all(&pool) + .await?; + assert!(body_chunks.len() == 12); + let tsvectors_table = format!("{}_{}.body_tsvectors", 
collection_name, pipeline_name1); + let tsvectors: Vec = + sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table)) + .fetch_all(&pool) + .await?; + assert!(tsvectors.len() == 12); + collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { + #[tokio::test] + async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); + let collection_name = "test_r_c_pss_5"; + let mut collection = Collection::new(collection_name, None)?; + let pipeline_name = "test_r_p_pss_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswre_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + }, + "splitter": { + "model": "recursive_character" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswre_21", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let mut pipeline = Pipeline::new("test_r_p_cvswre_1", None, None, None); + let documents = generate_dummy_documents(4); collection - .upsert_documents(generate_dummy_documents(3), None) + .upsert_documents(documents[..2].to_owned(), None) .await?; - let results = collection - .vector_search("Here is some query", &mut pipeline, None, Some(10)) + let status = collection.get_pipeline_status(&mut pipeline).await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "embeddings": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "tsvectors": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + } + }) + ); + collection.disable_pipeline(&pipeline).await?; + collection + .upsert_documents(documents[2..4].to_owned(), None) .await?; - assert!(results.len() == 3); + let status = collection.get_pipeline_status(&mut pipeline).await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 2, + "synced": 2, + "total": 4 + }, + "embeddings": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + "tsvectors": { + "not_synced": 0, + "synced": 2, + "total": 2 + }, + } + }) + ); + collection.enable_pipeline(&mut pipeline).await?; + let status = collection.get_pipeline_status(&mut pipeline).await?; + assert_eq!( + status.0, + json!({ + "title": { + "chunks": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + "embeddings": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + "tsvectors": { + "not_synced": 0, + "synced": 4, + "total": 4 + }, + } + }) + ); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { + #[tokio::test] + async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let collection_name = "test_r_c_cschpfp_4"; + let mut collection = Collection::new(collection_name, None)?; + let pipeline_name = "test_r_p_cschpfp_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqb_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - 
"full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small", + "hnsw": { + "m": 100, + "ef_construction": 200 + } + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqb_4", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(4), None) - .await?; - let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .limit(3) - .fetch_all() - .await?; - assert!(results.len() == 3); + let schema = format!("{collection_name}_{pipeline_name}"); + let full_embeddings_table_name = format!("{schema}.title_embeddings"); + let embeddings_table_name = full_embeddings_table_name.split('.').collect::>()[1]; + let pool = get_or_initialize_pool(&None).await?; + let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( + "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + embeddings_table_name, + schema + )).fetch_all(&pool).await?; + let names = results.iter().map(|(name, _)| name).collect::>(); + let definitions = results + .iter() + .map(|(_, definition)| definition) + .collect::>(); + assert!(names.contains(&&"title_pipeline_embedding_hnsw_vector_index".to_string())); + assert!(definitions.contains(&&format!("CREATE INDEX title_pipeline_embedding_hnsw_vector_index ON {full_embeddings_table_name} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')"))); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( - ) -> anyhow::Result<()> { + /////////////////////////////// + // Searches /////////////////// + /////////////////////////////// + + #[tokio::test] + async fn can_search_with_local_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("hkunlp/instructor-base".to_string()), - Some("python".to_string()), - Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), - ); - let splitter = Splitter::default(); + let collection_name = "test_r_c_cswle_121"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cswle_9"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqbapmpis_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small" + }, + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "semantic_search": { + "model": "intfloat/e5-small" + } } }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); + )?; collection.add_pipeline(&mut pipeline).await?; + let query = json!({ + "query": { + "full_text_search": { + 
"title": { + "query": "test 9", + "boost": 4.0 + }, + "body": { + "query": "Test", + "boost": 1.2 + } + }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + "body": { + "query": "This is the body test", + "parameters": { + "instruction": "Represent the Wikipedia question for retrieving supporting documents: ", + }, + "boost": 1.01 + }, + "notes": { + "query": "This is the notes test", + "boost": 1.01 + } + }, + "filter": { + "id": { + "$gt": 1 + } + } - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; + }, + "limit": 5 + }); let results = collection - .query() - .vector_recall( - "Here is some query", + .search(query.clone().into(), &mut pipeline) + .await?; + let ids: Vec = results["results"] + .as_array() + .unwrap() + .iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![9, 2, 7, 8, 3]); + + let pool = get_or_initialize_pool(&None).await?; + + let searches_table = format!("{}_{}.searches", collection_name, pipeline_name); + let searches: Vec<(i64, serde_json::Value)> = + sqlx::query_as(&query_builder!("SELECT id, query FROM %s", searches_table)) + .fetch_all(&pool) + .await?; + assert!(searches.len() == 1); + assert!(searches[0].0 == results["search_id"].as_i64().unwrap()); + assert!(searches[0].1 == query); + + let search_results_table = format!("{}_{}.search_results", collection_name, pipeline_name); + let search_results: Vec<(i64, i64, i64, serde_json::Value, i32)> = + sqlx::query_as(&query_builder!( + "SELECT id, search_id, document_id, scores, rank FROM %s ORDER BY rank ASC", + search_results_table + )) + .fetch_all(&pool) + .await?; + assert!(search_results.len() == 5); + // Document ids are 1 based in the db not 0 based like they are here + assert_eq!( + search_results.iter().map(|sr| sr.2).collect::>(), + vec![10, 3, 8, 9, 4] + ); + + let event = json!({"clicked": true}); + collection + .add_search_event( + results["search_id"].as_i64().unwrap(), + 2, + event.clone().into(), &pipeline, - Some( - json!({ - "instruction": "Represent the Wikipedia document for retrieval: " - }) - .into(), - ), ) - .limit(10) - .fetch_all() .await?; - assert!(results.len() == 3); + let search_events_table = format!("{}_{}.search_events", collection_name, pipeline_name); + let (search_result, retrieved_event): (i64, Json) = sqlx::query_as(&query_builder!( + "SELECT search_result, event FROM %s LIMIT 1", + search_events_table + )) + .fetch_one(&pool) + .await?; + assert_eq!(search_result, 2); + assert_eq!(event, retrieved_event.0); + collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { + #[tokio::test] + async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); + let collection_name = "test r_c_cswre_66"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cswre_8"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqbwre_1", - Some(model), - Some(splitter), + pipeline_name, Some( - serde_json::json!({ - 
"full_text_search": { - "active": true, - "configuration": "english" - } + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "text-embedding-ada-002", + "source": "openai", + }, + "full_text_search": { + "configuration": "english" + } + }, }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(4), None) - .await?; + let mut pipeline = Pipeline::new(pipeline_name, None)?; let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .limit(3) - .fetch_all() + .search( + json!({ + "query": { + "full_text_search": { + "body": { + "query": "Test", + "boost": 1.2 + } + }, + "semantic_search": { + "title": { + "query": "This is a test", + "boost": 2.0 + }, + "body": { + "query": "This is the body test", + "boost": 1.01 + }, + }, + "filter": { + "id": { + "$gt": 1 + } + } + }, + "limit": 5 + }) + .into(), + &mut pipeline, + ) .await?; - assert!(results.len() == 3); + let ids: Vec = results["results"] + .as_array() + .unwrap() + .iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![2, 3, 7, 4, 8]); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( - ) -> anyhow::Result<()> { - internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = - Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None); - let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None); - collection.add_pipeline(&mut pipeline).await?; + /////////////////////////////// + // Vector Searches //////////// + /////////////////////////////// - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; + #[tokio::test] + async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let collection_name = "test r_c_cvswle_9"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cvswle_0"; + let mut pipeline = Pipeline::new( + pipeline_name, + Some( + json!({ + "title": { + "semantic_search": { + "model": "hkunlp/instructor-base", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + } + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small" + }, + }, + }) + .into(), + ), + )?; + collection.add_pipeline(&mut pipeline).await?; let results = collection - .query() - .vector_recall( - "Here is some query", - &pipeline, - Some( - json!({ - "hnsw": { - "ef_search": 2 + .vector_search( + json!({ + "query": { + "fields": { + "title": { + "query": "Test document: 2", + "parameters": { + "instruction": "Represent the Wikipedia document for retrieval" + }, + 
"full_text_filter": "test" + }, + "body": { + "query": "Test document: 2" + }, + }, + "filter": { + "id": { + "$gt": 3 + } } - }) - .into(), - ), + }, + "limit": 5 + }) + .into(), + &mut pipeline, ) - .fetch_all() .await?; - assert!(results.len() == 3); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![8, 4, 7, 6, 9]); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings( - ) -> anyhow::Result<()> { + #[tokio::test] + async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new( - Some("text-embedding-ada-002".to_string()), - Some("openai".to_string()), - None, - ); - let splitter = Splitter::default(); + let collection_name = "test r_c_cvswre_7"; + let mut collection = Collection::new(collection_name, None)?; + let documents = generate_dummy_documents(10); + collection.upsert_documents(documents.clone(), None).await?; + let pipeline_name = "test_r_p_cvswre_0"; let mut pipeline = Pipeline::new( - "test_r_p_cvswqbachesvare_2", - Some(model), - Some(splitter), - None, - ); - let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None); + pipeline_name, + Some( + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "source": "openai", + "model": "text-embedding-ada-002" + }, + }, + }) + .into(), + ), + )?; collection.add_pipeline(&mut pipeline).await?; - - // Recreate the pipeline to replicate a more accurate example - let pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); - collection - .upsert_documents(generate_dummy_documents(3), None) - .await?; + let mut pipeline = Pipeline::new(pipeline_name, None)?; let results = collection - .query() - .vector_recall( - "Here is some query", - &pipeline, - Some( - json!({ - "hnsw": { - "ef_search": 2 + .vector_search( + json!({ + "query": { + "fields": { + "title": { + "full_text_filter": "test", + "query": "Test document: 2" + }, + "body": { + "query": "Test document: 2" + }, + }, + "filter": { + "id": { + "$gt": 3 + } } - }) - .into(), - ), + }, + "limit": 5 + }) + .into(), + &mut pipeline, ) - .fetch_all() .await?; - assert!(results.len() == 3); + let ids: Vec = results + .into_iter() + .map(|r| r["document"]["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6, 7, 9]); collection.archive().await?; Ok(()) } - #[sqlx::test] - async fn can_filter_vector_search() -> anyhow::Result<()> { + #[tokio::test] + async fn can_vector_search_with_query_builder() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); + let mut collection = Collection::new("test r_c_cvswqb_7", None)?; let mut pipeline = Pipeline::new( - "test_r_p_cfd_1", - Some(model), - Some(splitter), + "test_r_p_cvswqb_0", Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } + json!({ + "text": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, }) .into(), ), - ); - let mut collection = Collection::new("test_r_c_cfd_2", None); - collection.add_pipeline(&mut pipeline).await?; + )?; collection - 
.upsert_documents(generate_dummy_documents(5), None) + .upsert_documents(generate_dummy_documents(10), None) .await?; - - let filters = vec![ - (5, json!({}).into()), - ( - 3, + collection.add_pipeline(&mut pipeline).await?; + let results = collection + .query() + .vector_recall("test query", &pipeline, None) + .limit(3) + .filter( json!({ "metadata": { "id": { - "$lt": 3 + "$gt": 3 } - } - }) - .into(), - ), - ( - 1, - json!({ - "full_text_search": { + }, + "full_text": { "configuration": "english", - "text": "1", + "text": "test" } }) .into(), - ), - ]; - - for (expected_result_count, filter) in filters { - let results = collection - .query() - .vector_recall("Here is some query", &pipeline, None) - .filter(filter) - .fetch_all() - .await?; - assert_eq!(results.len(), expected_result_count); - } - + ) + .fetch_all() + .await?; + let ids: Vec = results + .into_iter() + .map(|r| r.2["id"].as_u64().unwrap()) + .collect(); + assert_eq!(ids, vec![4, 5, 6]); collection.archive().await?; Ok(()) } @@ -786,30 +1248,11 @@ mod tests { // Working With Documents ///// /////////////////////////////// - #[sqlx::test] + #[tokio::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cuafgd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - - let mut collection = Collection::new("test_r_c_cuagd_2", None); - collection.add_pipeline(&mut pipeline).await?; + let mut collection = Collection::new("test r_c_cuafgd_1", None)?; - // Test basic upsert let documents = vec![ serde_json::json!({"id": 1, "random_key": 10, "text": "hello world 1"}).into(), serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), @@ -819,7 +1262,6 @@ mod tests { let document = &collection.get_documents(None).await?[0]; assert_eq!(document["document"]["text"], "hello world 1"); - // Test upsert of text and metadata let documents = vec![ serde_json::json!({"id": 1, "text": "hello world new"}).into(), serde_json::json!({"id": 2, "random_key": 12}).into(), @@ -831,58 +1273,38 @@ mod tests { .get_documents(Some( serde_json::json!({ "filter": { - "metadata": { - "random_key": { - "$eq": 12 - } - } - } - }) - .into(), - )) - .await?; - assert_eq!(documents[0]["document"]["text"], "hello world 2"); - - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "metadata": { - "random_key": { - "$gte": 13 - } + "random_key": { + "$eq": 12 } } }) .into(), )) .await?; - assert_eq!(documents[0]["document"]["text"], "hello world 3"); + assert_eq!(documents[0]["document"]["random_key"], 12); let documents = collection .get_documents(Some( serde_json::json!({ "filter": { - "full_text_search": { - "configuration": "english", - "text": "new" + "random_key": { + "$gte": 13 } } }) .into(), )) .await?; - assert_eq!(documents[0]["document"]["text"], "hello world new"); - assert_eq!(documents[0]["document"]["id"].as_i64().unwrap(), 1); + assert_eq!(documents[0]["document"]["random_key"], 13); collection.archive().await?; Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cpgd_2", None); + let mut collection = Collection::new("test_r_c_cpgd_2", None)?; collection 
.upsert_documents(generate_dummy_documents(10), None) .await?; @@ -961,28 +1383,10 @@ mod tests { Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfapgd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - - let mut collection = Collection::new("test_r_c_cfapgd_1", None); - collection.add_pipeline(&mut pipeline).await?; + let mut collection = Collection::new("test_r_c_cfapgd_1", None)?; collection .upsert_documents(generate_dummy_documents(10), None) @@ -992,10 +1396,8 @@ mod tests { .get_documents(Some( serde_json::json!({ "filter": { - "metadata": { - "id": { - "$gte": 2 - } + "id": { + "$gte": 2 } }, "limit": 2, @@ -1016,10 +1418,8 @@ mod tests { .get_documents(Some( serde_json::json!({ "filter": { - "metadata": { - "id": { - "$lte": 5 - } + "id": { + "$lte": 5 } }, "limit": 100, @@ -1028,7 +1428,6 @@ mod tests { .into(), )) .await?; - let last_row_id = documents.last().unwrap()["row_id"].as_i64().unwrap(); assert_eq!( documents .into_iter() @@ -1037,55 +1436,14 @@ mod tests { vec![4, 5] ); - let documents = collection - .get_documents(Some( - serde_json::json!({ - "filter": { - "full_text_search": { - "configuration": "english", - "text": "document" - } - }, - "limit": 100, - "last_row_id": last_row_id - }) - .into(), - )) - .await?; - assert_eq!( - documents - .into_iter() - .map(|d| d["document"]["id"].as_i64().unwrap()) - .collect::>(), - vec![6, 7, 8, 9] - ); - collection.archive().await?; Ok(()) } - #[sqlx::test] + #[tokio::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::default(); - let splitter = Splitter::default(); - let mut pipeline = Pipeline::new( - "test_r_p_cfadd_1", - Some(model), - Some(splitter), - Some( - serde_json::json!({ - "full_text_search": { - "active": true, - "configuration": "english" - } - }) - .into(), - ), - ); - - let mut collection = Collection::new("test_r_c_cfadd_1", None); - collection.add_pipeline(&mut pipeline).await?; + let mut collection = Collection::new("test_r_c_cfadd_1", None)?; collection .upsert_documents(generate_dummy_documents(10), None) .await?; @@ -1093,10 +1451,8 @@ mod tests { collection .delete_documents( serde_json::json!({ - "metadata": { - "id": { - "$lt": 2 - } + "id": { + "$lt": 2 } }) .into(), @@ -1111,50 +1467,27 @@ mod tests { collection .delete_documents( serde_json::json!({ - "full_text_search": { - "configuration": "english", - "text": "2" - } - }) - .into(), - ) - .await?; - let documents = collection.get_documents(None).await?; - assert_eq!(documents.len(), 7); - assert!(documents - .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() > 2)); - - collection - .delete_documents( - serde_json::json!({ - "metadata": { - "id": { - "$gte": 6 - } - }, - "full_text_search": { - "configuration": "english", - "text": "6" + "id": { + "$gte": 6 } }) .into(), ) .await?; let documents = collection.get_documents(None).await?; - assert_eq!(documents.len(), 6); + assert_eq!(documents.len(), 4); assert!(documents .iter() - .all(|d| d["document"]["id"].as_i64().unwrap() != 6)); + .all(|d| d["document"]["id"].as_i64().unwrap() < 6)); collection.archive().await?; Ok(()) } - #[sqlx::test] - fn can_order_documents() -> 
anyhow::Result<()> { + #[tokio::test] + async fn can_order_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cod_1", None); + let mut collection = Collection::new("test_r_c_cod_1", None)?; collection .upsert_documents( vec![ @@ -1231,10 +1564,75 @@ mod tests { Ok(()) } - #[sqlx::test] - fn can_merge_metadata() -> anyhow::Result<()> { + #[tokio::test] + async fn can_update_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test_r_c_cud_5", None)?; + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "text": "Test Document 1" + }) + .into(), + json!({ + "id": 2, + "text": "Test Document 1" + }) + .into(), + json!({ + "id": 3, + "text": "Test Document 1" + }) + .into(), + ], + None, + ) + .await?; + collection + .upsert_documents( + vec![ + json!({ + "id": 1, + "number": 0, + }) + .into(), + json!({ + "id": 2, + "number": 1, + }) + .into(), + json!({ + "id": 3, + "number": 2, + }) + .into(), + ], + None, + ) + .await?; + let documents = collection + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) + .await?; + assert_eq!( + documents + .iter() + .map(|d| d["document"]["number"].as_i64().unwrap()) + .collect::>(), + vec![0, 1, 2] + ); + for document in documents { + assert!(document["document"]["text"].as_str().is_none()); + } + collection.archive().await?; + Ok(()) + } + + #[tokio::test] + async fn can_merge_metadata() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let mut collection = Collection::new("test_r_c_cmm_4", None); + let mut collection = Collection::new("test_r_c_cmm_5", None)?; collection .upsert_documents( vec![ @@ -1276,6 +1674,7 @@ mod tests { .collect::>(), vec![(97, 12), (98, 11), (99, 10)] ); + collection .upsert_documents( vec![ @@ -1300,18 +1699,14 @@ mod tests { ], Some( json!({ - "metadata": { - "merge": true - } + "merge": true }) .into(), ), ) .await?; let documents = collection - .get_documents(Some( - json!({"order_by": {"number": {"number": "asc"}}}).into(), - )) + .get_documents(Some(json!({"order_by": {"number": "asc"}}).into())) .await?; assert_eq!( @@ -1328,4 +1723,52 @@ mod tests { collection.archive().await?; Ok(()) } + + /////////////////////////////// + // ER Diagram ///////////////// + /////////////////////////////// + + #[tokio::test] + async fn generate_er_diagram() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut pipeline = Pipeline::new( + "test_p_ged_57", + Some( + json!({ + "title": { + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "body": { + "splitter": { + "model": "recursive_character" + }, + "semantic_search": { + "model": "intfloat/e5-small" + }, + "full_text_search": { + "configuration": "english" + } + }, + "notes": { + "semantic_search": { + "model": "intfloat/e5-small" + } + } + }) + .into(), + ), + )?; + let mut collection = Collection::new("test_r_c_ged_2", None)?; + collection.add_pipeline(&mut pipeline).await?; + let diagram = collection.generate_er_diagram(&mut pipeline).await?; + assert!(!diagram.is_empty()); + println!("{diagram}"); + collection.archive().await?; + Ok(()) + } } diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs index b67dec8fa..6133ff1fc 100644 --- a/pgml-sdks/pgml/src/migrations/mod.rs +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -8,6 +8,9 @@ use crate::get_or_initialize_pool; #[path = 
"pgml--0.9.1--0.9.2.rs"] mod pgml091_092; +#[path = "pgml--0.9.2--1.0.0.rs"] +mod pgml092_100; + // There is probably a better way to write this type and the version_migrations variable in the dispatch_migrations function type MigrateFn = Box) -> BoxFuture<'static, anyhow::Result> + Send + Sync>; @@ -48,8 +51,10 @@ pub fn migrate() -> BoxFuture<'static, anyhow::Result<()>> { async fn dispatch_migrations(pool: PgPool, collections: Vec<(String, i64)>) -> anyhow::Result<()> { // The version of the SDK that the migration was written for, and the migration function - let version_migrations: [(&'static str, MigrateFn); 1] = - [("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed()))]; + let version_migrations: [(&'static str, MigrateFn); 2] = [ + ("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed())), + ("0.9.2", Box::new(|p, c| pgml092_100::migrate(p, c).boxed())), + ]; let mut collections = collections.into_iter().into_group_map(); for (version, migration) in version_migrations.into_iter() { diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs new file mode 100644 index 000000000..29e4f559a --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.2--1.0.0.rs @@ -0,0 +1,9 @@ +use sqlx::PgPool; +use tracing::instrument; + +#[instrument(skip(_pool))] +pub async fn migrate(_pool: PgPool, _: Vec) -> anyhow::Result { + anyhow::bail!( + "There is no automatic migration to SDK version 1.0. Please upgrade the SDK and create a new collection, or contact your PostgresML support to create a migration plan.", + ) +} diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 49197ecf1..ff320c0de 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -1,11 +1,10 @@ -use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sqlx::postgres::PgPool; +use sqlx::{Pool, Postgres}; use tracing::instrument; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, models, + models, types::{DateTime, Json}, }; @@ -45,6 +44,7 @@ impl From<&ModelRuntime> for &'static str { } } +#[allow(dead_code)] #[derive(Debug, Clone)] pub(crate) struct ModelDatabaseData { pub id: i64, @@ -57,7 +57,6 @@ pub struct Model { pub name: String, pub runtime: ModelRuntime, pub parameters: Json, - project_info: Option, pub(crate) database_data: Option, } @@ -93,21 +92,18 @@ impl Model { name, runtime, parameters, - project_info: None, database_data: None, } } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify model without project info"); - let mut parameters = self.parameters.clone(); parameters .as_object_mut() @@ -120,7 +116,7 @@ impl Model { .bind(project_info.id) .bind(Into::<&str>::into(&self.runtime)) .bind(¶meters) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; let model = if let Some(m) = model { @@ -136,7 +132,7 @@ impl Model { .bind("successful") .bind(serde_json::json!({})) .bind(serde_json::json!({})) - .fetch_one(&pool) + .fetch_one(pool) .await?; model }; @@ -148,53 +144,6 @@ impl Model { } Ok(()) } - - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - self.project_info = 
Some(project_info); - } - - #[instrument(skip(self))] - pub(crate) async fn to_dict(&mut self) -> anyhow::Result<Json> { - self.verify_in_database(false).await?; - - let database_data = self - .database_data - .as_ref() - .context("Model must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "created_at": database_data.created_at, - "name": self.name, - "runtime": Into::<&str>::into(&self.runtime), - "parameters": *self.parameters, - }) - .into()) - } - - async fn get_pool(&self) -> anyhow::Result<PgPool> { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method model.get_pool()")? - .database_url; - get_or_initialize_pool(database_url).await - } -} - -impl From<models::PipelineWithModelAndSplitter> for Model { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - name: x.model_hyperparams["name"].as_str().unwrap().to_string(), - runtime: x.model_runtime.as_str().into(), - parameters: x.model_hyperparams, - project_info: None, - database_data: Some(ModelDatabaseData { - id: x.model_id, - created_at: x.model_created_at, - }), - } - } } impl From<models::Model> for Model { @@ -203,7 +152,6 @@ impl From<models::Model> for Model { name: model.hyperparams["name"].as_str().unwrap().to_string(), runtime: model.runtime.as_str().into(), parameters: model.hyperparams, - project_info: None, database_data: Some(ModelDatabaseData { id: model.id, created_at: model.created_at, diff --git a/pgml-sdks/pgml/src/models.rs b/pgml-sdks/pgml/src/models.rs index 07440d4e3..e5208d4d8 100644 --- a/pgml-sdks/pgml/src/models.rs +++ b/pgml-sdks/pgml/src/models.rs @@ -5,17 +5,15 @@ use sqlx::FromRow; use crate::types::{DateTime, Json}; -// A pipeline +// A multi-field pipeline #[enum_def] #[derive(FromRow)] pub struct Pipeline { pub id: i64, pub name: String, pub created_at: DateTime, - pub model_id: i64, - pub splitter_id: i64, pub active: bool, - pub parameters: Json, + pub schema: Json, } // A model used to perform some task @@ -38,24 +36,6 @@ pub struct Splitter { pub parameters: Json, } -// A pipeline with its model and splitter -#[derive(FromRow, Clone)] -pub struct PipelineWithModelAndSplitter { - pub pipeline_id: i64, - pub pipeline_name: String, - pub pipeline_created_at: DateTime, - pub pipeline_active: bool, - pub pipeline_parameters: Json, - pub model_id: i64, - pub model_created_at: DateTime, - pub model_runtime: String, - pub model_hyperparams: Json, - pub splitter_id: i64, - pub splitter_created_at: DateTime, - pub splitter_name: String, - pub splitter_parameters: Json, -} - // A document #[enum_def] #[derive(FromRow, Serialize)] @@ -65,18 +45,16 @@ pub struct Document { #[serde(with = "uuid::serde::compact")] // See: https://docs.rs/uuid/latest/uuid/serde/index.html pub source_uuid: Uuid, - pub metadata: Json, - pub text: String, + pub document: Json, } impl Document { - pub fn into_user_friendly_json(mut self) -> Json { - self.metadata["text"] = self.text.into(); + pub fn into_user_friendly_json(self) -> Json { serde_json::json!({ "row_id": self.id, "created_at": self.created_at, "source_uuid": self.source_uuid, - "document": self.metadata, + "document": self.document, }) .into() } @@ -109,7 +87,13 @@ pub struct Chunk { pub id: i64, pub created_at: DateTime, pub document_id: i64, - pub splitter_id: i64, pub chunk_index: i64, pub chunk: String, } + +// A tsvector of a document +#[derive(FromRow)] +pub struct TSVector { + pub id: i64, + pub created_at: DateTime, +}
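Before the pipeline.rs changes themselves, a usage sketch of the v1.0 API that the tests above exercise: pipelines are configured per document field with a JSON schema rather than a (Model, Splitter, parameters) triple, constructors are fallible, and search is driven by a single JSON query object. This is a minimal sketch, assuming `PGML_DATABASE_URL` points at a Postgres database with the pgml extension installed and that `pgml` re-exports `Collection` and `Pipeline`; the collection, pipeline, model, and field names are illustrative.

```rust
use pgml::{Collection, Pipeline};
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // v1.0: the pipeline is described per field; `new` is fallible because
    // the schema is validated up front.
    let mut pipeline = Pipeline::new(
        "my_pipeline",
        Some(
            json!({
                "body": {
                    "splitter": { "model": "recursive_character" },
                    "semantic_search": { "model": "intfloat/e5-small" },
                    "full_text_search": { "configuration": "english" }
                }
            })
            .into(),
        ),
    )?;
    // None here falls back to PGML_DATABASE_URL / DATABASE_URL (see lib.rs above).
    let mut collection = Collection::new("my_collection", None)?;
    collection.add_pipeline(&mut pipeline).await?;
    collection
        .upsert_documents(vec![json!({"id": 1, "body": "hello world"}).into()], None)
        .await?;
    // Search now takes one JSON object naming fields, filters, and a limit.
    let results = collection
        .vector_search(
            json!({
                "query": {
                    "fields": { "body": { "query": "hello" } }
                },
                "limit": 5
            })
            .into(),
            &mut pipeline,
        )
        .await?;
    println!("top hit: {:?}", results.first());
    collection.archive().await?;
    Ok(())
}
```

diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index dceff4270..6dada5159 100644 ---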
a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -1,25 +1,139 @@ use anyhow::Context; -use indicatif::MultiProgress; -use rust_bridge::{alias, alias_manual, alias_methods}; -use sqlx::{Executor, PgConnection, PgPool}; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::Relaxed; -use tokio::join; +use rust_bridge::{alias, alias_methods}; +use serde::Deserialize; +use serde_json::json; +use sqlx::{Executor, PgConnection, Pool, Postgres, Transaction}; +use std::collections::HashMap; use tracing::instrument; +use crate::debug_sqlx_query; use crate::{ collection::ProjectInfo, - get_or_initialize_pool, model::{Model, ModelRuntime}, models, queries, query_builder, remote_embeddings::build_remote_embeddings, splitter::Splitter, types::{DateTime, Json, TryToNumeric}, - utils, }; #[cfg(feature = "python")] -use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython}; +use crate::types::JsonPython; + +type ParsedSchema = HashMap<String, FieldAction>; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidSplitterAction { + model: Option<String>, + parameters: Option<Json>, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidEmbedAction { + model: String, + source: Option<String>, + parameters: Option<Json>, + hnsw: Option<Json>, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(deny_unknown_fields)] +pub struct FullTextSearchAction { + configuration: String, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct ValidFieldAction { + splitter: Option<ValidSplitterAction>, + semantic_search: Option<ValidEmbedAction>, + full_text_search: Option<FullTextSearchAction>, +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +pub struct HNSW { + m: u64, + ef_construction: u64, +} + +impl Default for HNSW { + fn default() -> Self { + Self { + m: 16, + ef_construction: 64, + } + } +} + +impl TryFrom<Json> for HNSW { + type Error = anyhow::Error; + fn try_from(value: Json) -> anyhow::Result<Self> { + let m = if !value["m"].is_null() { + value["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !value["ef_construction"].is_null() { + value["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")?
+ } else { + 64 + }; + Ok(Self { m, ef_construction }) + } +} + +#[derive(Debug, Clone)] +pub struct SplitterAction { + pub model: Splitter, +} + +#[derive(Debug, Clone)] +pub struct SemanticSearchAction { + pub model: Model, + pub hnsw: HNSW, +} + +#[derive(Debug, Clone)] +pub struct FieldAction { + pub splitter: Option, + pub semantic_search: Option, + pub full_text_search: Option, +} + +impl TryFrom for FieldAction { + type Error = anyhow::Error; + fn try_from(value: ValidFieldAction) -> Result { + let embed = value + .semantic_search + .map(|v| { + let model = Model::new(Some(v.model), v.source, v.parameters); + let hnsw = v + .hnsw + .map(HNSW::try_from) + .unwrap_or_else(|| Ok(HNSW::default()))?; + anyhow::Ok(SemanticSearchAction { model, hnsw }) + }) + .transpose()?; + let splitter = value + .splitter + .map(|v| { + let splitter = Splitter::new(v.model, v.parameters); + anyhow::Ok(SplitterAction { model: splitter }) + }) + .transpose()?; + Ok(Self { + splitter, + semantic_search: embed, + full_text_search: value.full_text_search, + }) + } +} #[derive(Debug, Clone)] pub struct InvividualSyncStatus { @@ -55,395 +169,525 @@ impl From for InvividualSyncStatus { } } -#[derive(alias_manual, Debug, Clone)] -pub struct PipelineSyncData { - pub chunks_status: InvividualSyncStatus, - pub embeddings_status: InvividualSyncStatus, - pub tsvectors_status: InvividualSyncStatus, -} - -impl From for Json { - fn from(value: PipelineSyncData) -> Self { - serde_json::json!({ - "chunks_status": *Json::from(value.chunks_status), - "embeddings_status": *Json::from(value.embeddings_status), - "tsvectors_status": *Json::from(value.tsvectors_status), - }) - .into() - } -} - -impl From for PipelineSyncData { - fn from(mut value: Json) -> Self { - Self { - chunks_status: Json::from(std::mem::take(&mut value["chunks_status"])).into(), - embeddings_status: Json::from(std::mem::take(&mut value["embeddings_status"])).into(), - tsvectors_status: Json::from(std::mem::take(&mut value["tsvectors_status"])).into(), - } - } -} - #[derive(Debug, Clone)] pub struct PipelineDatabaseData { pub id: i64, pub created_at: DateTime, - pub model_id: i64, - pub splitter_id: i64, } -/// A pipeline that processes documents #[derive(alias, Debug, Clone)] pub struct Pipeline { pub name: String, - pub model: Option, - pub splitter: Option, - pub parameters: Option, - project_info: Option, - pub(crate) database_data: Option, + pub schema: Option, + pub parsed_schema: Option, + database_data: Option, +} + +fn json_to_schema(schema: &Json) -> anyhow::Result { + schema + .as_object() + .context("Schema object must be a JSON object")? + .iter() + .try_fold(ParsedSchema::new(), |mut acc, (key, value)| { + if acc.contains_key(key) { + Err(anyhow::anyhow!("Schema contains duplicate keys")) + } else { + // First lets deserialize it normally + let action: ValidFieldAction = serde_json::from_value(value.to_owned())?; + // Now lets actually build the models and splitters + acc.insert(key.to_owned(), action.try_into()?); + Ok(acc) + } + }) } -#[alias_methods(new, get_status, to_dict)] +#[alias_methods(new)] impl Pipeline { - /// Creates a new [Pipeline] - /// - /// # Arguments - /// - /// * `name` - The name of the pipeline - /// * `model` - The pipeline [Model] - /// * `splitter` - The pipeline [Splitter] - /// * `parameters` - The parameters to the pipeline. 
Defaults to None - /// - /// # Example - /// - /// ``` - /// use pgml::{Pipeline, Model, Splitter}; - /// let model = Model::new(None, None, None); - /// let splitter = Splitter::new(None, None); - /// let pipeline = Pipeline::new("my_splitter", Some(model), Some(splitter), None); - /// ``` - pub fn new( - name: &str, - model: Option, - splitter: Option, - parameters: Option, - ) -> Self { - let parameters = Some(parameters.unwrap_or_default()); - Self { + pub fn new(name: &str, schema: Option) -> anyhow::Result { + let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?; + Ok(Self { name: name.to_string(), - model, - splitter, - parameters, - project_info: None, + schema, + parsed_schema, database_data: None, - } + }) } /// Gets the status of the [Pipeline] - /// This includes the status of the chunks, embeddings, and tsvectors - /// - /// # Example - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let status = pipeline.get_status().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn get_status(&mut self) -> anyhow::Result { - let pool = self.get_pool().await?; - - self.verify_in_database(false).await?; - let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - let database_data = self - .database_data + pub async fn get_status( + &mut self, + project_info: &ProjectInfo, + pool: &Pool, + ) -> anyhow::Result { + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must be verified to get status")?; + .context("Pipeline must have schema to get status")?; - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to get status")?; + let mut results = json!({}); - let project_name = &self.project_info.as_ref().unwrap().name; + let schema = format!("{}_{}", project_info.name, self.name); + let documents_table_name = format!("{}.documents", project_info.name); + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{schema}.{key}_chunks"); - // TODO: Maybe combine all of these into one query so it is faster - let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s WHERE splitter_id = $1), COUNT(id) FROM %s", - format!("{}.chunks", project_name), - format!("{}.documents", project_name) - )) - .bind(database_data.splitter_id) - .fetch_one(&pool).await?; - let chunks_status = InvividualSyncStatus { - synced: chunks_status.0.unwrap_or(0), - not_synced: chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), - total: chunks_status.1.unwrap_or(0), - }; + results[key] = json!({}); - let embeddings_status: (Option, Option) = sqlx::query_as(&query_builder!( - "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s WHERE splitter_id = $1)", - embeddings_table_name, - format!("{}.chunks", project_name) - )) - .bind(database_data.splitter_id) - .fetch_one(&pool) - .await?; - let embeddings_status = InvividualSyncStatus { - synced: embeddings_status.0.unwrap_or(0), - not_synced: embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), - total: embeddings_status.1.unwrap_or(0), - }; + if value.splitter.is_some() { + let chunks_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT COUNT(DISTINCT document_id) FROM %s), COUNT(id) FROM %s", + chunks_table_name, + 
documents_table_name + )) + .fetch_one(pool) + .await?; + results[key]["chunks"] = json!({ + "synced": chunks_status.0.unwrap_or(0), + "not_synced": chunks_status.1.unwrap_or(0) - chunks_status.0.unwrap_or(0), + "total": chunks_status.1.unwrap_or(0), + }); + } - let tsvectors_status = if parameters["full_text_search"]["active"] - == serde_json::Value::Bool(true) - { - sqlx::query_as(&query_builder!( - "SELECT (SELECT COUNT(*) FROM %s WHERE configuration = $1), (SELECT COUNT(*) FROM %s)", - format!("{}.documents_tsvectors", project_name), - format!("{}.documents", project_name) - )) - .bind(parameters["full_text_search"]["configuration"].as_str()) - .fetch_one(&pool).await? - } else { - (Some(0), Some(0)) - }; - let tsvectors_status = InvividualSyncStatus { - synced: tsvectors_status.0.unwrap_or(0), - not_synced: tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), - total: tsvectors_status.1.unwrap_or(0), - }; + if value.semantic_search.is_some() { + let embeddings_table_name = format!("{schema}.{key}_embeddings"); + let embeddings_status: (Option, Option) = + sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + embeddings_table_name, + chunks_table_name + )) + .fetch_one(pool) + .await?; + results[key]["embeddings"] = json!({ + "synced": embeddings_status.0.unwrap_or(0), + "not_synced": embeddings_status.1.unwrap_or(0) - embeddings_status.0.unwrap_or(0), + "total": embeddings_status.1.unwrap_or(0), + }); + } - Ok(PipelineSyncData { - chunks_status, - embeddings_status, - tsvectors_status, - }) + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{schema}.{key}_tsvectors"); + let tsvectors_status: (Option, Option) = sqlx::query_as(&query_builder!( + "SELECT (SELECT count(*) FROM %s), (SELECT count(*) FROM %s)", + tsvectors_table_name, + chunks_table_name + )) + .fetch_one(pool) + .await?; + results[key]["tsvectors"] = json!({ + "synced": tsvectors_status.0.unwrap_or(0), + "not_synced": tsvectors_status.1.unwrap_or(0) - tsvectors_status.0.unwrap_or(0), + "total": tsvectors_status.1.unwrap_or(0), + }); + } + } + Ok(results.into()) } #[instrument(skip(self))] - pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> { + pub(crate) async fn verify_in_database( + &mut self, + project_info: &ProjectInfo, + throw_if_exists: bool, + pool: &Pool, + ) -> anyhow::Result<()> { if self.database_data.is_none() { - let pool = self.get_pool().await?; - - let project_info = self - .project_info - .as_ref() - .expect("Cannot verify pipeline without project info"); - let pipeline: Option = sqlx::query_as(&query_builder!( "SELECT * FROM %s WHERE name = $1", format!("{}.pipelines", project_info.name) )) .bind(&self.name) - .fetch_optional(&pool) + .fetch_optional(pool) .await?; - let pipeline = if let Some(p) = pipeline { + let pipeline = if let Some(pipeline) = pipeline { if throw_if_exists { - anyhow::bail!("Pipeline {} already exists", p.name); + anyhow::bail!("Pipeline {} already exists. 
You do not need to add this pipeline to the collection as it has already been added.", pipeline.name); } - let model: models::Model = sqlx::query_as( - "SELECT id, created_at, runtime::TEXT, hyperparams FROM pgml.models WHERE id = $1", - ) - .bind(p.model_id) - .fetch_one(&pool) - .await?; - let mut model: Model = model.into(); - model.set_project_info(project_info.clone()); - self.model = Some(model); - - let splitter: models::Splitter = - sqlx::query_as("SELECT * FROM pgml.splitters WHERE id = $1") - .bind(p.splitter_id) - .fetch_one(&pool) - .await?; - let mut splitter: Splitter = splitter.into(); - splitter.set_project_info(project_info.clone()); - self.splitter = Some(splitter); - - p + + let mut parsed_schema = json_to_schema(&pipeline.schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { + splitter + .model + .verify_in_database(project_info, false, pool) + .await?; + } + if let Some(embed) = &mut value.semantic_search { + embed + .model + .verify_in_database(project_info, false, pool) + .await?; + } + } + self.schema = Some(pipeline.schema.clone()); + self.parsed_schema = Some(parsed_schema); + + pipeline } else { - let model = self - .model - .as_mut() - .expect("Cannot save pipeline without model"); - model.set_project_info(project_info.clone()); - model.verify_in_database(false).await?; - - let splitter = self - .splitter - .as_mut() - .expect("Cannot save pipeline without splitter"); - splitter.set_project_info(project_info.clone()); - splitter.verify_in_database(false).await?; - - sqlx::query_as(&query_builder!( - "INSERT INTO %s (name, model_id, splitter_id, parameters) VALUES ($1, $2, $3, $4) RETURNING *", - format!("{}.pipelines", project_info.name) - )) - .bind(&self.name) - .bind( - model - .database_data - .as_ref() - .context("Cannot save pipeline without model")? - .id, - ) - .bind( + let schema = self + .schema + .as_ref() + .context("Pipeline must have schema to store in database")?; + let mut parsed_schema = json_to_schema(schema)?; + + for (_key, value) in parsed_schema.iter_mut() { + if let Some(splitter) = &mut value.splitter { splitter - .database_data - .as_ref() - .context("Cannot save pipeline without splitter")? - .id, - ) - .bind(&self.parameters) - .fetch_one(&pool) - .await? 
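// For reference while reviewing `verify_in_database`: a sketch of a schema
// document that `json_to_schema` accepts, pieced together from the
// `ValidFieldAction` / `ValidSplitterAction` / `ValidEmbedAction` /
// `FullTextSearchAction` deserializers above. The field name "body", the
// splitter name, and the model name are illustrative placeholders, not values
// taken from this changeset; `hnsw` may be omitted entirely, in which case it
// falls back to `m = 16, ef_construction = 64` per `HNSW::default()`.
fn example_pipeline() -> anyhow::Result<Pipeline> {
    let schema: Json = serde_json::json!({
        "body": {
            "splitter": { "model": "recursive_character" },     // optional; placeholder name
            "semantic_search": {
                "model": "intfloat/e5-small",                   // placeholder model name
                "hnsw": { "m": 16, "ef_construction": 64 }      // optional; these are the defaults
            },
            "full_text_search": { "configuration": "english" }  // optional
        }
    })
    .into();
    Pipeline::new("my_pipeline", Some(schema))
}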
- }; + .model + .verify_in_database(project_info, false, pool) + .await?; + } + if let Some(embed) = &mut value.semantic_search { + embed + .model + .verify_in_database(project_info, false, pool) + .await?; + } + } + self.parsed_schema = Some(parsed_schema); + + // Here we actually insert the pipeline into the collection.pipelines table + // and create the collection_pipeline schema and required tables + let mut transaction = pool.begin().await?; + let pipeline = sqlx::query_as(&query_builder!( + "INSERT INTO %s (name, schema) VALUES ($1, $2) RETURNING *", + format!("{}.pipelines", project_info.name) + )) + .bind(&self.name) + .bind(&self.schema) + .fetch_one(&mut *transaction) + .await?; + self.create_tables(project_info, &mut transaction).await?; + transaction.commit().await?; + pipeline + }; self.database_data = Some(PipelineDatabaseData { id: pipeline.id, created_at: pipeline.created_at, - model_id: pipeline.model_id, - splitter_id: pipeline.splitter_id, - }); - self.parameters = Some(pipeline.parameters); + }) } Ok(()) } - #[instrument(skip(self, mp))] - pub(crate) async fn execute( + #[instrument(skip(self))] + async fn create_tables( &mut self, - document_ids: &Option>, - mp: MultiProgress, + project_info: &ProjectInfo, + transaction: &mut Transaction<'_, Postgres>, ) -> anyhow::Result<()> { - // TODO: Chunk document_ids if there are too many - - // A couple notes on the following methods - // - Atomic bools are required to work nicely with pyo3 otherwise we would use cells - // - We use green threads because they are cheap, but we want to be super careful to not - // return an error before stopping the green thread. To meet that end, we map errors and - // return types often - let chunk_ids = self.sync_chunks(document_ids, &mp).await?; - self.sync_embeddings(chunk_ids, &mp).await?; - self.sync_tsvectors(document_ids, &mp).await?; - Ok(()) - } + let collection_name = &project_info.name; + let documents_table_name = format!("{}.documents", collection_name); - #[instrument(skip(self, mp))] - async fn sync_chunks( - &mut self, - document_ids: &Option>, - mp: &MultiProgress, - ) -> anyhow::Result>> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let database_data = self - .database_data - .as_mut() - .context("Pipeline must be verified to generate chunks")?; - - let project_info = self - .project_info + let schema = format!("{}_{}", collection_name, self.name); + + transaction + .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", schema).as_str()) + .await?; + + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must have project info to generate chunks")?; - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating chunks"); - - // This part is a bit tricky - // We want to return the ids for all chunks we inserted OR would have inserted if they didn't already exist - // The query is structured in such a way to not insert any chunks that already exist so we - // can't rely on the data returned from the inset queries, we need to query the chunks table - // It is important we return the ids for chunks we would have inserted if they didn't already exist so we are robust to random crashes - let is_done = AtomicBool::new(false); - let work = async { - let chunk_ids: Result>, _> = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, - &format!("{}.chunks", project_info.name), - &format!("{}.documents", 
project_info.name), - &format!("{}.chunks", project_info.name) - )) - .bind(database_data.splitter_id) - .bind(document_ids) - .execute(&pool) - .await - .map_err(|e| { - is_done.store(true, Relaxed); - e - })?; - sqlx::query_scalar(&query_builder!( - "SELECT id FROM %s WHERE document_id = ANY($1)", - &format!("{}.chunks", project_info.name) - )) - .bind(document_ids) - .fetch_all(&pool) - .await - .map(Some) - } else { + .context("Pipeline must have schema to create_tables")?; + + let searches_table_name = format!("{schema}.searches"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCHES_TABLE, + searches_table_name + ) + .as_str(), + ) + .await?; + + let search_results_table_name = format!("{schema}.search_results"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCH_RESULTS_TABLE, + search_results_table_name, + &searches_table_name, + &documents_table_name + ) + .as_str(), + ) + .await?; + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + "search_results_search_id_rank_index", + search_results_table_name, + "search_id, rank" + ) + .as_str(), + ) + .await?; + + let search_events_table_name = format!("{schema}.search_events"); + transaction + .execute( + query_builder!( + queries::CREATE_PIPELINES_SEARCH_EVENTS_TABLE, + search_events_table_name, + &search_results_table_name + ) + .as_str(), + ) + .await?; + + for (key, value) in parsed_schema.iter() { + let chunks_table_name = format!("{}.{}_chunks", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TABLE, + chunks_table_name, + documents_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_chunk_document_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + chunks_table_name, + "document_id" + ) + .as_str(), + ) + .await?; + + if let Some(embed) = &value.semantic_search { + let embeddings_table_name = format!("{}.{}_embeddings", schema, key); + let embedding_length = match &embed.model.runtime { + ModelRuntime::Python => { + let embedding: (Vec,) = sqlx::query_as( + "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") + .bind(&embed.model.name) + .bind(&embed.model.parameters) + .fetch_one(&mut **transaction).await?; + embedding.0.len() as i64 + } + t => { + let remote_embeddings = build_remote_embeddings( + t.to_owned(), + &embed.model.name, + Some(&embed.model.parameters), + )?; + remote_embeddings.get_embedding_size().await? 
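// The embedding-dimension probe used just above, in isolation: embed one known
// string through `pgml.embed` and measure the returned vector, exactly as the
// `ModelRuntime::Python` arm does before sizing the embeddings table. A
// standalone sketch; the model name is a placeholder and `pool` is assumed to
// be any connected `sqlx::PgPool` against a database with the pgml extension:
async fn embedding_width(pool: &sqlx::PgPool) -> anyhow::Result<i64> {
    let (embedding,): (Vec<f32>,) = sqlx::query_as(
        "SELECT embedding FROM pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => '{}'::jsonb) AS embedding",
    )
    .bind("intfloat/e5-small") // placeholder model name
    .fetch_one(pool)
    .await?;
    // The vector length becomes the dimension of the pgvector column
    Ok(embedding.len() as i64)
}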
+ } + }; + + // Create the embeddings table sqlx::query(&query_builder!( - queries::GENERATE_CHUNKS, - &format!("{}.chunks", project_info.name), - &format!("{}.documents", project_info.name), - &format!("{}.chunks", project_info.name) + queries::CREATE_EMBEDDINGS_TABLE, + &embeddings_table_name, + chunks_table_name, + embedding_length )) - .bind(database_data.splitter_id) - .execute(&pool) - .await - .map(|_t| None) - }; - is_done.store(true, Relaxed); - chunk_ids - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + .execute(&mut **transaction) + .await?; + let index_name = format!("{}_pipeline_embedding_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + &embeddings_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_with_parameters = format!( + "WITH (m = {}, ef_construction = {})", + embed.hnsw.m, embed.hnsw.ef_construction + ); + let index_name = format!("{}_pipeline_embedding_hnsw_vector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &embeddings_table_name, + "embedding vector_cosine_ops", + index_with_parameters + ) + .as_str(), + ) + .await?; } - }; - let (chunk_ids, _) = join!(work, progress_work); - progress_bar.set_message("done generating chunks"); - progress_bar.finish(); - Ok(chunk_ids?) + + // Create the tsvectors table + if value.full_text_search.is_some() { + let tsvectors_table_name = format!("{}.{}_tsvectors", schema, key); + transaction + .execute( + query_builder!( + queries::CREATE_CHUNKS_TSVECTORS_TABLE, + tsvectors_table_name, + chunks_table_name + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_chunk_id_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX, + "", + index_name, + tsvectors_table_name, + "chunk_id" + ) + .as_str(), + ) + .await?; + let index_name = format!("{}_pipeline_tsvector_index", key); + transaction + .execute( + query_builder!( + queries::CREATE_INDEX_USING_GIN, + "", + index_name, + tsvectors_table_name, + "ts" + ) + .as_str(), + ) + .await?; + } + } + Ok(()) } - #[instrument(skip(self, mp))] - async fn sync_embeddings( + #[instrument(skip(self))] + pub(crate) async fn sync_documents( &mut self, - chunk_ids: Option>, - mp: &MultiProgress, + document_ids: Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let embeddings_table_name = self.create_or_get_embeddings_table().await?; - - let model = self - .model + // We are assuming we have manually verified the pipeline before doing this + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must be verified to generate embeddings")?; - - let database_data = self - .database_data - .as_mut() - .context("Pipeline must be verified to generate embeddings")?; + .context("Pipeline must have schema to execute")?; + + for (key, value) in parsed_schema.iter() { + let chunk_ids = self + .sync_chunks_for_documents( + key, + value.splitter.as_ref().map(|v| &v.model), + &document_ids, + project_info, + transaction, + ) + .await?; + if !chunk_ids.is_empty() { + if let Some(embed) = &value.semantic_search { + self.sync_embeddings_for_chunks( + key, + &embed.model, + &chunk_ids, + project_info, + transaction, + ) + .await?; + } + if let 
Some(full_text_search) = &value.full_text_search { + self.sync_tsvectors_for_chunks( + key, + &full_text_search.configuration, + &chunk_ids, + project_info, + transaction, + ) + .await?; + } + } + } + Ok(()) + } - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate embeddings")?; + #[instrument(skip(self))] + async fn sync_chunks_for_documents( + &self, + key: &str, + splitter: Option<&Splitter>, + document_ids: &Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, + ) -> anyhow::Result> { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + let query = query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER, + &json_key_query, + documents_table_name, + &chunks_table_name, + &chunks_table_name, + &chunks_table_name + ); + debug_sqlx_query!( + GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER, + query, + splitter_database_data.id, + document_ids + ); + sqlx::query_scalar(&query) + .bind(splitter_database_data.id) + .bind(document_ids) + .fetch_all(&mut **transaction) + .await + .map_err(anyhow::Error::msg) + } else { + let query = query_builder!( + queries::GENERATE_CHUNKS_FOR_DOCUMENT_IDS, + &chunks_table_name, + &json_key_query, + &documents_table_name, + &chunks_table_name, + &json_key_query + ); + debug_sqlx_query!(GENERATE_CHUNKS_FOR_DOCUMENT_IDS, query, document_ids); + sqlx::query_scalar(&query) + .bind(document_ids) + .fetch_all(&mut **transaction) + .await + .map_err(anyhow::Error::msg) + } + } + #[instrument(skip(self))] + async fn sync_embeddings_for_chunks( + &self, + key: &str, + model: &Model, + chunk_ids: &Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, + ) -> anyhow::Result<()> { // Remove the stored name from the parameters let mut parameters = model.parameters.clone(); parameters @@ -451,370 +695,248 @@ impl Pipeline { .context("Model parameters must be an object")? .remove("name"); - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating emmbeddings"); - - let is_done = AtomicBool::new(false); - // We need to be careful about how we handle errors here. We do not want to return an error - // from the async block before setting is_done to true. If we do, the progress bar will - // will load forever. 
We also want to make sure to propogate any errors we have - let work = async { - let res = match model.runtime { - ModelRuntime::Python => if chunk_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, - embeddings_table_name, - &format!("{}.chunks", project_info.name), - embeddings_table_name - )) + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + let query = query_builder!( + queries::GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + embeddings_table_name, + chunks_table_name + ); + debug_sqlx_query!( + GENERATE_EMBEDDINGS_FOR_CHUNK_IDS, + query, + model.name, + parameters.0, + chunk_ids + ); + sqlx::query(&query) .bind(&model.name) .bind(¶meters) - .bind(database_data.splitter_id) .bind(chunk_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_EMBEDDINGS, - embeddings_table_name, - &format!("{}.chunks", project_info.name), - embeddings_table_name - )) - .bind(&model.name) - .bind(¶meters) - .bind(database_data.splitter_id) - .execute(&pool) - .await - } - .map_err(|e| anyhow::anyhow!(e)) - .map(|_t| ()), - r => { - let remote_embeddings = build_remote_embeddings(r, &model.name, ¶meters)?; - remote_embeddings - .generate_embeddings( - &embeddings_table_name, - &format!("{}.chunks", project_info.name), - database_data.splitter_id, - chunk_ids, - &pool, - ) - .await - .map(|_t| ()) - } - }; - is_done.store(true, Relaxed); - res - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + .execute(&mut **transaction) + .await?; } - }; - let (res, _) = join!(work, progress_work); - progress_bar.set_message("done generating embeddings"); - progress_bar.finish(); - res + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + Some(chunk_ids), + transaction, + ) + .await?; + } + } + Ok(()) } #[instrument(skip(self))] - async fn sync_tsvectors( - &mut self, - document_ids: &Option>, - mp: &MultiProgress, + async fn sync_tsvectors_for_chunks( + &self, + key: &str, + configuration: &str, + chunk_ids: &Vec, + project_info: &ProjectInfo, + transaction: &mut Transaction<'static, Postgres>, ) -> anyhow::Result<()> { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to generate tsvectors")?; - - if parameters["full_text_search"]["active"] != serde_json::Value::Bool(true) { - return Ok(()); - } - - let project_info = self - .project_info - .as_ref() - .context("Pipeline must have project info to generate tsvectors")?; - - let progress_bar = mp - .add(utils::default_progress_spinner(1)) - .with_prefix(self.name.clone()) - .with_message("generating tsvectors for full text search"); - - let configuration = parameters["full_text_search"]["configuration"] - .as_str() - .context("Full text search configuration must be a string")?; - - let is_done = AtomicBool::new(false); - let work = async { - let res = if document_ids.is_some() { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS_FOR_DOCUMENT_IDS, - format!("{}.documents_tsvectors", project_info.name), - configuration, - configuration, - 
format!("{}.documents", project_info.name) - )) - .bind(document_ids) - .execute(&pool) - .await - } else { - sqlx::query(&query_builder!( - queries::GENERATE_TSVECTORS, - format!("{}.documents_tsvectors", project_info.name), - configuration, - configuration, - format!("{}.documents", project_info.name) - )) - .execute(&pool) - .await - }; - is_done.store(true, Relaxed); - res.map(|_t| ()).map_err(|e| anyhow::anyhow!(e)) - }; - let progress_work = async { - while !is_done.load(Relaxed) { - progress_bar.inc(1); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - } - }; - let (res, _) = join!(work, progress_work); - progress_bar.set_message("done generating tsvectors for full text search"); - progress_bar.finish(); - res + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + let query = query_builder!( + queries::GENERATE_TSVECTORS_FOR_CHUNK_IDS, + tsvectors_table_name, + configuration, + chunks_table_name + ); + debug_sqlx_query!(GENERATE_TSVECTORS_FOR_CHUNK_IDS, query, chunk_ids); + sqlx::query(&query) + .bind(chunk_ids) + .execute(&mut **transaction) + .await?; + Ok(()) } #[instrument(skip(self))] - pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - let pool = self.get_pool().await?; - - let collection_name = &self - .project_info + pub(crate) async fn resync( + &mut self, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + // We are assuming we have manually verified the pipeline before doing this + let parsed_schema = self + .parsed_schema .as_ref() - .context("Pipeline must have project info to get the embeddings table name")? - .name; - let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); - - // Notice that we actually check for existence of the table in the database instead of - // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid - // generating embeddings just to get the length if we don't need to - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)" + .context("Pipeline must have schema to execute")?; + // Before doing any syncing, delete all old and potentially outdated documents + for (key, _value) in parsed_schema.iter() { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + connection + .execute(query_builder!("DELETE FROM %s CASCADE", chunks_table_name).as_str()) + .await?; + } + for (key, value) in parsed_schema.iter() { + self.resync_chunks( + key, + value.splitter.as_ref().map(|v| &v.model), + project_info, + connection, ) - .bind(&self - .project_info - .as_ref() - .context("Pipeline must have project info to get the embeddings table name")?.name) - .bind(format!("{}_embeddings", self.name)).fetch_one(&pool).await?; - - if !exists { - let model = self - .model - .as_ref() - .context("Pipeline must be verified to create embeddings table")?; - - // Remove the stored name from the model parameters - let mut model_parameters = model.parameters.clone(); - model_parameters - .as_object_mut() - .context("Model parameters must be an object")? 
- .remove("name"); - - let embedding_length = match &model.runtime { - ModelRuntime::Python => { - let embedding: (Vec,) = sqlx::query_as( - "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") - .bind(&model.name) - .bind(model_parameters) - .fetch_one(&pool).await?; - embedding.0.len() as i64 - } - t => { - let remote_embeddings = - build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; - remote_embeddings.get_embedding_size().await? - } - }; - - let mut transaction = pool.begin().await?; - sqlx::query(&query_builder!( - queries::CREATE_EMBEDDINGS_TABLE, - &embeddings_table_name, - &format!( - "{}.chunks", - self.project_info - .as_ref() - .context("Pipeline must have project info to create the embeddings table")? - .name - ), - embedding_length - )) - .execute(&mut *transaction) .await?; - let index_name = format!("{}_pipeline_created_at_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "created_at" - ) - .as_str(), - ) - .await?; - let index_name = format!("{}_pipeline_chunk_id_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX, - "", - index_name, - &embeddings_table_name, - "chunk_id" - ) - .as_str(), - ) - .await?; - // See: https://github.com/pgvector/pgvector - let (m, ef_construction) = match &self.parameters { - Some(p) => { - let m = if !p["hnsw"]["m"].is_null() { - p["hnsw"]["m"] - .try_to_u64() - .context("hnsw.m must be an integer")? - } else { - 16 - }; - let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { - p["hnsw"]["ef_construction"] - .try_to_u64() - .context("hnsw.ef_construction must be an integer")? - } else { - 64 - }; - (m, ef_construction) - } - None => (16, 64), - }; - let index_with_parameters = - format!("WITH (m = {}, ef_construction = {})", m, ef_construction); - let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); - transaction - .execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &embeddings_table_name, - "embedding vector_cosine_ops", - index_with_parameters - ) - .as_str(), + if let Some(embed) = &value.semantic_search { + self.resync_embeddings(key, &embed.model, project_info, connection) + .await?; + } + if let Some(full_text_search) = &value.full_text_search { + self.resync_tsvectors( + key, + &full_text_search.configuration, + project_info, + connection, ) .await?; - transaction.commit().await?; + } } - - Ok(embeddings_table_name) + Ok(()) } #[instrument(skip(self))] - pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) { - if self.model.is_some() { - self.model - .as_mut() - .unwrap() - .set_project_info(project_info.clone()); - } - if self.splitter.is_some() { - self.splitter - .as_mut() - .unwrap() - .set_project_info(project_info.clone()); + async fn resync_chunks( + &self, + key: &str, + splitter: Option<&Splitter>, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let documents_table_name = format!("{}.documents", project_info.name); + let json_key_query = format!("document->>'{}'", key); + + if let Some(splitter) = splitter { + let splitter_database_data = splitter + .database_data + .as_ref() + .context("Splitter must be verified to sync chunks")?; + let query = query_builder!( + queries::GENERATE_CHUNKS_WITH_SPLITTER, + &json_key_query, + 
&documents_table_name, + &chunks_table_name, + &chunks_table_name + ); + debug_sqlx_query!( + GENERATE_CHUNKS_WITH_SPLITTER, + query, + splitter_database_data.id + ); + sqlx::query(&query) + .bind(splitter_database_data.id) + .execute(connection) + .await?; + } else { + let query = query_builder!( + queries::GENERATE_CHUNKS, + &chunks_table_name, + &json_key_query, + &documents_table_name + ); + debug_sqlx_query!(GENERATE_CHUNKS, query); + sqlx::query(&query).execute(connection).await?; } - self.project_info = Some(project_info); + Ok(()) } - /// Convert the [Pipeline] to [Json] - /// - /// # Example: - /// - /// ``` - /// use pgml::Collection; - /// - /// async fn example() -> anyhow::Result<()> { - /// let mut collection = Collection::new("my_collection", None); - /// let mut pipeline = collection.get_pipeline("my_pipeline").await?; - /// let pipeline_dict = pipeline.to_dict().await?; - /// Ok(()) - /// } - /// ``` #[instrument(skip(self))] - pub async fn to_dict(&mut self) -> anyhow::Result { - self.verify_in_database(false).await?; - - let status = self.get_status().await?; - - let model_dict = self - .model - .as_mut() - .context("Pipeline must be verified to call to_dict")? - .to_dict() - .await?; - - let splitter_dict = self - .splitter - .as_mut() - .context("Pipeline must be verified to call to_dict")? - .to_dict() - .await?; + async fn resync_embeddings( + &self, + key: &str, + model: &Model, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + // Remove the stored name from the parameters + let mut parameters = model.parameters.clone(); + parameters + .as_object_mut() + .context("Model parameters must be an object")? + .remove("name"); - let database_data = self - .database_data - .as_ref() - .context("Pipeline must be verified to call to_dict")?; + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let embeddings_table_name = + format!("{}_{}.{}_embeddings", project_info.name, self.name, key); + + match model.runtime { + ModelRuntime::Python => { + let query = query_builder!( + queries::GENERATE_EMBEDDINGS, + embeddings_table_name, + chunks_table_name + ); + debug_sqlx_query!(GENERATE_EMBEDDINGS, query, model.name, parameters.0); + sqlx::query(&query) + .bind(&model.name) + .bind(¶meters) + .execute(connection) + .await?; + } + r => { + let remote_embeddings = build_remote_embeddings(r, &model.name, Some(¶meters))?; + remote_embeddings + .generate_embeddings( + &embeddings_table_name, + &chunks_table_name, + None, + connection, + ) + .await?; + } + } + Ok(()) + } - let parameters = self - .parameters - .as_ref() - .context("Pipeline must be verified to call to_dict")?; - - Ok(serde_json::json!({ - "id": database_data.id, - "name": self.name, - "model": *model_dict, - "splitter": *splitter_dict, - "parameters": *parameters, - "status": *Json::from(status), - }) - .into()) + #[instrument(skip(self))] + async fn resync_tsvectors( + &self, + key: &str, + configuration: &str, + project_info: &ProjectInfo, + connection: &mut PgConnection, + ) -> anyhow::Result<()> { + let chunks_table_name = format!("{}_{}.{}_chunks", project_info.name, self.name, key); + let tsvectors_table_name = format!("{}_{}.{}_tsvectors", project_info.name, self.name, key); + + let query = query_builder!( + queries::GENERATE_TSVECTORS, + tsvectors_table_name, + configuration, + chunks_table_name + ); + debug_sqlx_query!(GENERATE_TSVECTORS, query); + sqlx::query(&query).execute(connection).await?; + Ok(()) } - async fn get_pool(&self) 
-> anyhow::Result { - let database_url = &self - .project_info - .as_ref() - .context("Project info required to call method pipeline.get_pool()")? - .database_url; - get_or_initialize_pool(database_url).await + #[instrument(skip(self))] + pub(crate) async fn get_parsed_schema( + &mut self, + project_info: &ProjectInfo, + pool: &Pool, + ) -> anyhow::Result { + self.verify_in_database(project_info, false, pool).await?; + Ok(self.parsed_schema.as_ref().unwrap().clone()) } + #[instrument] pub(crate) async fn create_pipelines_table( project_info: &ProjectInfo, conn: &mut PgConnection, ) -> anyhow::Result<()> { let pipelines_table_name = format!("{}.pipelines", project_info.name); sqlx::query(&query_builder!( - queries::CREATE_PIPELINES_TABLE, + queries::PIPELINES_TABLE, pipelines_table_name )) .execute(&mut *conn) @@ -834,20 +956,17 @@ impl Pipeline { } } -impl From for Pipeline { - fn from(x: models::PipelineWithModelAndSplitter) -> Self { - Self { - model: Some(x.clone().into()), - splitter: Some(x.clone().into()), - name: x.pipeline_name, - project_info: None, - database_data: Some(PipelineDatabaseData { - id: x.pipeline_id, - created_at: x.pipeline_created_at, - model_id: x.model_id, - splitter_id: x.splitter_id, - }), - parameters: Some(x.pipeline_parameters), - } +impl TryFrom for Pipeline { + type Error = anyhow::Error; + fn try_from(value: models::Pipeline) -> anyhow::Result { + let parsed_schema = json_to_schema(&value.schema).unwrap(); + // NOTE: We do not set the database data here even though we have it + // self.verify_in_database() also verifies all models in the schema so we don't want to set it here + Ok(Self { + name: value.name, + schema: Some(value.schema), + parsed_schema: Some(parsed_schema), + database_data: None, + }) } } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 8e793691e..1ea7001bf 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -1,6 +1,7 @@ ///////////////////////////// // CREATE TABLE QUERIES ///// ///////////////////////////// + pub const CREATE_COLLECTIONS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS pgml.collections ( id serial8 PRIMARY KEY, @@ -13,15 +14,13 @@ CREATE TABLE IF NOT EXISTS pgml.collections ( ); "#; -pub const CREATE_PIPELINES_TABLE: &str = r#" +pub const PIPELINES_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, name text NOT NULL, created_at timestamp NOT NULL DEFAULT now(), - model_id int8 NOT NULL REFERENCES pgml.models ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, active BOOLEAN NOT NULL DEFAULT TRUE, - parameters jsonb NOT NULL DEFAULT '{}', + schema jsonb NOT NULL, UNIQUE (name) ); "#; @@ -31,8 +30,8 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), source_uuid uuid NOT NULL, - metadata jsonb NOT NULL DEFAULT '{}', - text text NOT NULL, + document jsonb NOT NULL, + version jsonb NOT NULL DEFAULT '{}'::jsonb, UNIQUE (source_uuid) ); "#; @@ -50,10 +49,9 @@ CREATE TABLE IF NOT EXISTS pgml.splitters ( pub const CREATE_CHUNKS_TABLE: &str = r#"CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - splitter_id int8 NOT NULL REFERENCES pgml.splitters ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY 
DEFERRED, chunk_index int8 NOT NULL, chunk text NOT NULL, - UNIQUE (document_id, splitter_id, chunk_index) + UNIQUE (document_id, chunk_index) ); "#; @@ -67,20 +65,47 @@ CREATE TABLE IF NOT EXISTS %s ( ); "#; -pub const CREATE_DOCUMENTS_TSVECTORS_TABLE: &str = r#" +pub const CREATE_CHUNKS_TSVECTORS_TABLE: &str = r#" CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), - document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, - configuration text NOT NULL, + chunk_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED, ts tsvector, - UNIQUE (configuration, document_id) + UNIQUE (chunk_id) +); +"#; + +pub const CREATE_PIPELINES_SEARCHES_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + created_at timestamp NOT NULL DEFAULT now(), + query jsonb +); +"#; + +pub const CREATE_PIPELINES_SEARCH_RESULTS_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + search_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + document_id int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + scores jsonb NOT NULL, + rank integer NOT NULL +); +"#; + +pub const CREATE_PIPELINES_SEARCH_EVENTS_TABLE: &str = r#" +CREATE TABLE IF NOT EXISTS %s ( + id serial8 PRIMARY KEY, + created_at timestamp NOT NULL DEFAULT now(), + search_result int8 NOT NULL REFERENCES %s ON DELETE CASCADE, + event jsonb NOT NULL ); "#; ///////////////////////////// // CREATE INDICES /////////// ///////////////////////////// + pub const CREATE_INDEX: &str = r#" CREATE INDEX %d IF NOT EXISTS %s ON %s (%d); "#; @@ -94,32 +119,102 @@ CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; ///////////////////////////// -// Other Big Queries //////// +// Inserting Search Events // ///////////////////////////// -pub const GENERATE_TSVECTORS: &str = r#" -INSERT INTO %s (document_id, configuration, ts) + + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user calls collection.add_search_event +// Required indexes: +// search_results table | "search_results_search_id_rank_index" btree (search_id, rank) +// Used to insert a search event +pub const INSERT_SEARCH_EVENT: &str = r#" +INSERT INTO %s (search_result, event) VALUES ((SELECT id FROM %s WHERE search_id = $1 AND rank = $2), $3) +"#; + +///////////////////////////// +// Upserting Documents ////// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a user upserts documents +// Required indexes: +// documents table | - "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid) +// Used to upsert a document and merge the previous metadata on conflict +// The values of the query and the source_uuid binding are built when used +pub const UPSERT_DOCUMENT_AND_MERGE_METADATA: &str = r#" +WITH prev AS ( + SELECT id, document FROM %s WHERE source_uuid = ANY({binding_parameter}) +) INSERT INTO %s (source_uuid, document, version) +VALUES {values_parameters} +ON CONFLICT (source_uuid) DO UPDATE SET document = %s.document || EXCLUDED.document, version = EXCLUDED.version +RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id) +"#; + +// Tag: CRITICAL_QUERY +// Checked: True
// Trigger: Runs whenever a user upserts documents +// Required indexes: +// - documents table | "documents_source_uuid_key" UNIQUE CONSTRAINT, btree (source_uuid) +// Used to upsert a document and overwrite the previous document on conflict +// The values of the query 
and the source_uuid binding are built when used +pub const UPSERT_DOCUMENT: &str = r#" +WITH prev AS ( + SELECT id, document FROM %s WHERE source_uuid = ANY({binding_parameter}) +) INSERT INTO %s (source_uuid, document, version) +VALUES {values_parameters} +ON CONFLICT (source_uuid) DO UPDATE SET document = EXCLUDED.document, version = EXCLUDED.version +RETURNING id, (SELECT document FROM prev WHERE prev.id = %s.id) +"#; + +///////////////////////////// +// Generating TSVectors ///// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is syncing documents and does full_text_search +// Required indexes: +// - chunks table | "{key}_tsvectors_pkey" PRIMARY KEY, btree (id) +// Used to generate tsvectors for specific chunks +pub const GENERATE_TSVECTORS_FOR_CHUNK_IDS: &str = r#" +INSERT INTO %s (chunk_id, ts) SELECT id, - '%d' configuration, - to_tsvector('%d', text) ts + to_tsvector('%d', chunk) ts FROM %s -ON CONFLICT (document_id, configuration) DO UPDATE SET ts = EXCLUDED.ts; +WHERE id = ANY ($1) +ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; -pub const GENERATE_TSVECTORS_FOR_DOCUMENT_IDS: &str = r#" -INSERT INTO %s (document_id, configuration, ts) +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is resyncing and does full_text_search +// Required indexes: None +// Used to generate tsvectors for an entire collection +pub const GENERATE_TSVECTORS: &str = r#" +INSERT INTO %s (chunk_id, ts) SELECT id, - '%d' configuration, - to_tsvector('%d', text) ts + to_tsvector('%d', chunk) ts FROM - %s -WHERE id = ANY ($1) -ON CONFLICT (document_id, configuration) DO NOTHING; + %s chunks +ON CONFLICT (chunk_id) DO UPDATE SET ts = EXCLUDED.ts; "#; -pub const GENERATE_EMBEDDINGS: &str = r#" +///////////////////////////// +// Generating Embeddings //// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is syncing documents and does semantic_search +// Required indexes: +// - chunks table | "{key}_chunks_pkey" PRIMARY KEY, btree (id) +// Used to generate embeddings for specific chunks +pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" INSERT INTO %s (chunk_id, embedding) SELECT id, @@ -131,17 +226,16 @@ SELECT FROM %s WHERE - splitter_id = $3 - AND id NOT IN ( - SELECT - chunk_id - from - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; + id = ANY ($3) +ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding "#; -pub const GENERATE_EMBEDDINGS_FOR_CHUNK_IDS: &str = r#" +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is resyncing and does semantic_search +// Required indexes: None +// Used to generate embeddings for an entire collection +pub const GENERATE_EMBEDDINGS: &str = r#" INSERT INTO %s (chunk_id, embedding) SELECT id, @@ -152,169 +246,166 @@ SELECT ) FROM %s -WHERE - splitter_id = $3 - AND id = ANY ($4) - AND id NOT IN ( - SELECT - chunk_id - from - %s - ) -ON CONFLICT (chunk_id) DO NOTHING; +ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding; "#; -pub const EMBED_AND_VECTOR_SEARCH: &str = r#" -WITH pipeline AS ( +///////////////////////////// +// Generating Chunks /////// +///////////////////////////// + +// Tag: CRITICAL_QUERY +// Checked: False +// Used to generate chunks for specific documents with a splitter +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS_WITH_SPLITTER: &str = r#" +WITH splitter AS ( SELECT - model_id + name, + parameters FROM - %s + pgml.splitters WHERE - name 
= $1 + id = $1 ), -model AS ( +new AS ( SELECT - hyperparams - FROM - pgml.models + documents.id AS document_id, + pgml.chunk (( + SELECT + name + FROM splitter), %d, ( + SELECT + parameters + FROM splitter)) AS chunk_t +FROM + %s AS documents WHERE - id = (SELECT model_id FROM pipeline) + id = ANY ($2) ), -embedding AS ( - SELECT - pgml.embed( - transformer => (SELECT hyperparams->>'name' FROM model), - text => $2, - kwargs => $3 - )::vector AS embedding -) -SELECT - embeddings.embedding <=> (SELECT embedding FROM embedding) score, - chunks.chunk, - documents.metadata -FROM - %s embeddings - INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id - INNER JOIN %s documents ON documents.id = chunks.document_id - ORDER BY - score ASC - LIMIT - $4; -"#; - -pub const VECTOR_SEARCH: &str = r#" -SELECT - embeddings.embedding <=> $1::vector score, - chunks.chunk, - documents.metadata -FROM - %s embeddings - INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id - INNER JOIN %s documents ON documents.id = chunks.document_id - ORDER BY - score ASC - LIMIT - $2; +del AS ( + DELETE FROM %s chunks + WHERE chunk_index > ( + SELECT + MAX((chunk_t).chunk_index) + FROM + new + WHERE + new.document_id = chunks.document_id + GROUP BY + new.document_id) + AND chunks.document_id = ANY ( + SELECT + document_id + FROM + new)) + INSERT INTO %s (document_id, chunk_index, chunk) +SELECT + new.document_id, + (chunk_t).chunk_index, + (chunk_t).chunk +FROM + new + LEFT OUTER JOIN %s chunks ON chunks.document_id = new.document_id + AND chunks.chunk_index = (chunk_t).chunk_index +WHERE (chunk_t).chunk <> COALESCE(chunks.chunk, '') +ON CONFLICT (document_id, chunk_index) + DO UPDATE SET + chunk = EXCLUDED.chunk +RETURNING + id; "#; -pub const GENERATE_CHUNKS: &str = r#" -WITH splitter as ( - SELECT - name, - parameters - FROM - pgml.splitters - WHERE - id = $1 -) +// Tag: CRITICAL_QUERY +// Checked: True +// Trigger: Runs whenever a pipeline is syncing documents and the key does not have a splitter +// Required indexes: +// - documents table | "documents_pkey" PRIMARY KEY, btree (id) +// - chunks table | "{key}_pipeline_chunk_document_id_index" btree (document_id) +// Used to generate chunks for specific documents without a splitter +// This query just copies the document key into the chunk +pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" INSERT INTO %s( - document_id, splitter_id, chunk_index, - chunk -) + document_id, chunk_index, chunk +) SELECT - document_id, - $1, - (chunk).chunk_index, - (chunk).chunk -FROM - ( - select - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - text, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - ( - SELECT - id, - text - FROM - %s - WHERE - id NOT IN ( - SELECT - document_id - FROM - %s - WHERE - splitter_id = $1 - ) - ) AS documents - ) chunks -ON CONFLICT (document_id, splitter_id, chunk_index) DO NOTHING + documents.id, + 1, + %d +FROM %s documents +LEFT OUTER JOIN %s chunks ON chunks.document_id = documents.id +WHERE documents.%d <> COALESCE(chunks.chunk, '') + AND documents.id = ANY($1) +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id "#; -pub const GENERATE_CHUNKS_FOR_DOCUMENT_IDS: &str = r#" -WITH splitter as ( SELECT - name, - parameters + name, + parameters FROM - pgml.splitters + pgml.splitters WHERE - id = $1 -)
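// Roughly the SQL the no-splitter path above expands to once `query_builder!`
// has filled in its slots (`%s` takes table names, `%d` takes raw SQL
// fragments such as the `document->>'...'` accessor). Collection
// "my_collection", pipeline "my_pipeline", and schema key "body" are
// placeholder names; the resync variant (GENERATE_CHUNKS) is the same insert
// with the join, the WHERE clause, and the `ANY($1)` restriction dropped:
const GENERATE_CHUNKS_FOR_DOCUMENT_IDS_EXPANDED_EXAMPLE: &str = r#"
INSERT INTO my_collection_my_pipeline.body_chunks (document_id, chunk_index, chunk)
SELECT
    documents.id,
    1,
    document->>'body'
FROM my_collection.documents documents
LEFT OUTER JOIN my_collection_my_pipeline.body_chunks chunks ON chunks.document_id = documents.id
WHERE documents.document->>'body' <> COALESCE(chunks.chunk, '')
    AND documents.id = ANY($1)
ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk
RETURNING id
"#;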
-INSERT INTO %s( - document_id, splitter_id, chunk_index, - chunk + id = $1 +), +new AS ( + SELECT + documents.id AS document_id, + pgml.chunk (( + SELECT + name + FROM splitter), %d, ( + SELECT + parameters + FROM splitter)) AS chunk_t +FROM + %s AS documents +), +del AS ( + DELETE FROM %s chunks + WHERE chunk_index > ( + SELECT + MAX((chunk_t).chunk_index) + FROM + new + WHERE + new.document_id = chunks.document_id + GROUP BY + new.document_id) + AND chunks.document_id = ANY ( + SELECT + document_id + FROM + new)) +INSERT INTO %s (document_id, chunk_index, chunk) +SELECT + new.document_id, + (chunk_t).chunk_index, + (chunk_t).chunk +FROM + new +ON CONFLICT (document_id, chunk_index) + DO UPDATE SET + chunk = EXCLUDED.chunk; +"#; + +// Tag: CRITICAL_QUERY +// Trigger: Runs whenever a pipeline is resyncing +// Required indexes: None +// Checked: True +// Used to generate chunks for an entire collection +pub const GENERATE_CHUNKS: &str = r#" +INSERT INTO %s ( + document_id, chunk_index, chunk ) -SELECT - document_id, - $1, - (chunk).chunk_index, - (chunk).chunk -FROM - ( - select - id AS document_id, - pgml.chunk( - (SELECT name FROM splitter), - text, - (SELECT parameters FROM splitter) - ) AS chunk - FROM - ( - SELECT - id, - text - FROM - %s - WHERE - id = ANY($2) - AND id NOT IN ( - SELECT - document_id - FROM - %s - WHERE - splitter_id = $1 - ) - ) AS documents - ) chunks -ON CONFLICT (document_id, splitter_id, chunk_index) DO NOTHING +SELECT + id, + 1, + %d +FROM %s +ON CONFLICT (document_id, chunk_index) DO UPDATE SET chunk = EXCLUDED.chunk RETURNING id "#; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 98fbe104a..4250f9db1 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -1,56 +1,47 @@ +// NOTE: DEPRECATED +// This whole file is legacy and is only here to be backwards compatible with collection.query() +// No new things should be added here, instead add new items to collection.vector_search + use anyhow::Context; use rust_bridge::{alias, alias_methods}; -use sea_query::{ - query::SelectStatement, Alias, CommonTableExpression, Expr, Func, JoinType, Order, - PostgresQueryBuilder, Query, QueryStatementWriter, WithClause, -}; -use sea_query_binder::SqlxBinder; -use std::borrow::Cow; +use serde_json::json; use tracing::instrument; -use crate::{ - filter_builder, get_or_initialize_pool, - model::ModelRuntime, - models, - pipeline::Pipeline, - query_builder, - remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, - Collection, -}; +use crate::{pipeline::Pipeline, types::Json, Collection}; #[cfg(feature = "python")] use crate::{pipeline::PipelinePython, types::JsonPython}; -#[derive(Clone, Debug)] -struct QueryBuilderState {} - #[derive(alias, Clone, Debug)] pub struct QueryBuilder { - query: SelectStatement, - with: WithClause, collection: Collection, - query_string: Option, + query: Json, pipeline: Option, - query_parameters: Option, } #[alias_methods(limit, filter, vector_recall, to_full_string, fetch_all)] impl QueryBuilder { pub fn new(collection: Collection) -> Self { + let query = json!({ + "query": { + "fields": { + "text": { + + } + } + } + }) + .into(); Self { - query: SelectStatement::new(), - with: WithClause::new(), collection, - query_string: None, + query, pipeline: None, - query_parameters: None, } } #[instrument(skip(self))] pub fn limit(mut self, limit: u64) -> Self { - self.query.limit(limit); + self.query["limit"] = json!(limit); 
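// The query document the deprecated builder accumulates before handing off to
// `collection.vector_search`, shown for a hypothetical chain
// `.limit(10).filter(...).vector_recall("What is PostgresML?", &pipeline, None)`.
// The metadata filter value and full-text filter term are placeholders:
fn example_legacy_query() -> Json {
    serde_json::json!({
        "query": {
            "fields": {
                "text": {
                    "query": "What is PostgresML?",
                    "full_text_filter": "PostgresML"   // from filter({"full_text": {"text": ...}})
                }
            },
            "filter": { "user_id": 1 }                 // from filter({"metadata": ...}); placeholder
        },
        "limit": 10
    })
    .into()
}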
self } @@ -61,62 +52,15 @@ impl QueryBuilder { .as_object_mut() .expect("Filter must be a Json object"); if let Some(f) = filter.remove("metadata") { - self = self.filter_metadata(f); + self.query["query"]["filter"] = f; } - if let Some(f) = filter.remove("full_text_search") { - self = self.filter_full_text(f); + if let Some(mut f) = filter.remove("full_text") { + self.query["query"]["fields"]["text"]["full_text_filter"] = + std::mem::take(&mut f["text"]); } self } - #[instrument(skip(self))] - fn filter_metadata(mut self, filter: serde_json::Value) -> Self { - let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata").build(); - self.query.cond_where(filter); - self - } - - #[instrument(skip(self))] - fn filter_full_text(mut self, mut filter: serde_json::Value) -> Self { - let filter = filter - .as_object_mut() - .expect("Full text filter must be a Json object"); - let configuration = match filter.get("configuration") { - Some(config) => config.as_str().expect("Configuration must be a string"), - None => "english", - }; - let filter_text = filter - .get("text") - .expect("Filter must contain a text field") - .as_str() - .expect("Text must be a string"); - self.query - .join_as( - JoinType::InnerJoin, - self.collection - .documents_tsvectors_table_name - .to_table_tuple(), - Alias::new("documents_tsvectors"), - Expr::col((SIden::Str("documents"), SIden::Str("id"))) - .equals((SIden::Str("documents_tsvectors"), SIden::Str("document_id"))), - ) - .and_where( - Expr::col(( - SIden::Str("documents_tsvectors"), - SIden::Str("configuration"), - )) - .eq(configuration), - ) - .and_where(Expr::cust_with_values( - format!( - "documents_tsvectors.ts @@ plainto_tsquery('{}', $1)", - configuration - ), - [filter_text], - )); - self - } - #[instrument(skip(self))] pub fn vector_recall( mut self, @@ -124,221 +68,37 @@ impl QueryBuilder { pipeline: &Pipeline, query_parameters: Option, ) -> Self { - // Save these in case of failure self.pipeline = Some(pipeline.clone()); - self.query_string = Some(query.to_owned()); - self.query_parameters = query_parameters.clone(); - - let mut query_parameters = query_parameters.unwrap_or_default().0; - // If they did set hnsw, remove it before we pass it to the model - query_parameters - .as_object_mut() - .expect("Query parameters must be a Json object") - .remove("hnsw"); - let embeddings_table_name = - format!("{}.{}_embeddings", self.collection.name, pipeline.name); - - // Build the pipeline CTE - let mut pipeline_cte = Query::select(); - pipeline_cte - .from_as( - self.collection.pipelines_table_name.to_table_tuple(), - SIden::Str("pipeline"), - ) - .columns([models::PipelineIden::ModelId]) - .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name)); - let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte); - pipeline_cte.table_name(Alias::new("pipeline")); - - // Build the model CTE - let mut model_cte = Query::select(); - model_cte - .from_as( - (SIden::Str("pgml"), SIden::Str("models")), - SIden::Str("model"), - ) - .columns([models::ModelIden::Hyperparams]) - .and_where(Expr::cust("id = (SELECT model_id FROM pipeline)")); - let mut model_cte = CommonTableExpression::from_select(model_cte); - model_cte.table_name(Alias::new("model")); - - // Build the embedding CTE - let mut embedding_cte = Query::select(); - embedding_cte.expr_as( - Func::cast_as( - Func::cust(SIden::Str("pgml.embed")).args([ - Expr::cust("transformer => (SELECT hyperparams->>'name' FROM model)"), - Expr::cust_with_values("text => $1", 
-                    Expr::cust_with_values("kwargs => $1", [query_parameters]),
-                ]),
-                Alias::new("vector"),
-            ),
-            Alias::new("embedding"),
-        );
-        let mut embedding_cte = CommonTableExpression::from_select(embedding_cte);
-        embedding_cte.table_name(Alias::new("embedding"));
-
-        // Build the where clause
-        let mut with_clause = WithClause::new();
-        self.with = with_clause
-            .cte(pipeline_cte)
-            .cte(model_cte)
-            .cte(embedding_cte)
-            .to_owned();
-
-        // Build the query
-        self.query
-            .expr(Expr::cust(
-                "(embeddings.embedding <=> (SELECT embedding from embedding)) score",
-            ))
-            .columns([
-                (SIden::Str("chunks"), SIden::Str("chunk")),
-                (SIden::Str("documents"), SIden::Str("metadata")),
-            ])
-            .from_as(
-                embeddings_table_name.to_table_tuple(),
-                SIden::Str("embeddings"),
-            )
-            .join_as(
-                JoinType::InnerJoin,
-                self.collection.chunks_table_name.to_table_tuple(),
-                Alias::new("chunks"),
-                Expr::col((SIden::Str("chunks"), SIden::Str("id")))
-                    .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))),
-            )
-            .join_as(
-                JoinType::InnerJoin,
-                self.collection.documents_table_name.to_table_tuple(),
-                Alias::new("documents"),
-                Expr::col((SIden::Str("documents"), SIden::Str("id")))
-                    .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
-            )
-            .order_by(SIden::Str("score"), Order::Asc);
-
+        self.query["query"]["fields"]["text"]["query"] = json!(query);
+        if let Some(query_parameters) = query_parameters {
+            self.query["query"]["fields"]["text"]["model_parameters"] = query_parameters.0;
+        }
         self
     }

     #[instrument(skip(self))]
     pub async fn fetch_all(mut self) -> anyhow::Result<Vec<(f64, String, Json)>> {
-        let pool = get_or_initialize_pool(&self.collection.database_url).await?;
-
-        let mut query_parameters = self.query_parameters.unwrap_or_default();
-
-        let (sql, values) = self
-            .query
-            .clone()
-            .with(self.with.clone())
-            .build_sqlx(PostgresQueryBuilder);
-
-        let result: Result<Vec<(f64, String, Json)>, _> =
-            if !query_parameters["hnsw"]["ef_search"].is_null() {
-                let mut transaction = pool.begin().await?;
-                let ef_search = query_parameters["hnsw"]["ef_search"]
-                    .try_to_i64()
-                    .context("ef_search must be an integer")?;
-                sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search))
-                    .execute(&mut *transaction)
-                    .await?;
-                let results = sqlx::query_as_with(&sql, values)
-                    .fetch_all(&mut *transaction)
-                    .await;
-                transaction.commit().await?;
-                results
-            } else {
-                sqlx::query_as_with(&sql, values).fetch_all(&pool).await
-            };
-
-        match result {
-            Ok(r) => Ok(r),
-            Err(e) => match e.as_database_error() {
-                Some(d) => {
-                    if d.code() == Some(Cow::from("XX000")) {
-                        // Explicitly get and set the model
-                        let project_info = self.collection.get_project_info().await?;
-                        let pipeline = self
-                            .pipeline
-                            .as_mut()
-                            .context("Need pipeline to call fetch_all on query builder with remote embeddings")?;
-                        pipeline.set_project_info(project_info);
-                        pipeline.verify_in_database(false).await?;
-                        let model = pipeline
-                            .model
-                            .as_ref()
-                            .context("Pipeline must be verified to perform vector search with remote embeddings")?;
-
-                        // If the model runtime is python, the error was not caused by an unsupported runtime
-                        if model.runtime == ModelRuntime::Python {
-                            return Err(anyhow::anyhow!(e));
-                        }
-
-                        let hnsw_parameters = query_parameters
-                            .as_object_mut()
-                            .context("Query parameters must be a Json object")?
-                            .remove("hnsw");
-
-                        let remote_embeddings =
-                            build_remote_embeddings(model.runtime, &model.name, &query_parameters)?;
-                        let mut embeddings = remote_embeddings
-                            .embed(vec![self
-                                .query_string
-                                .to_owned()
-                                .context("Must have query_string to call fetch_all on query_builder with remote embeddings")?])
-                            .await?;
-                        let embedding = std::mem::take(&mut embeddings[0]);
-
-                        let mut embedding_cte = Query::select();
-                        embedding_cte
-                            .expr(Expr::cust_with_values("$1::vector embedding", [embedding]));
-
-                        let mut embedding_cte = CommonTableExpression::from_select(embedding_cte);
-                        embedding_cte.table_name(Alias::new("embedding"));
-                        let mut with_clause = WithClause::new();
-                        with_clause.cte(embedding_cte);
-
-                        let (sql, values) = self
-                            .query
-                            .clone()
-                            .with(with_clause)
-                            .build_sqlx(PostgresQueryBuilder);
-
-                        if let Some(parameters) = hnsw_parameters {
-                            let mut transaction = pool.begin().await?;
-                            let ef_search = parameters["ef_search"]
-                                .try_to_i64()
-                                .context("ef_search must be an integer")?;
-                            sqlx::query(&query_builder!(
-                                "SET LOCAL hnsw.ef_search = %d",
-                                ef_search
-                            ))
-                            .execute(&mut *transaction)
-                            .await?;
-                            let results = sqlx::query_as_with(&sql, values)
-                                .fetch_all(&mut *transaction)
-                                .await;
-                            transaction.commit().await?;
-                            results
-                        } else {
-                            sqlx::query_as_with(&sql, values).fetch_all(&pool).await
-                        }
-                        .map_err(|e| anyhow::anyhow!(e))
-                    } else {
-                        Err(anyhow::anyhow!(e))
-                    }
-                }
-                None => Err(anyhow::anyhow!(e)),
-            },
-        }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. - score, id, metadata)).collect())
-    }
-
-    // This is mostly so our SDKs in other languages have some way to debug
-    pub fn to_full_string(&self) -> String {
-        self.to_string()
-    }
-}
-
-impl std::fmt::Display for QueryBuilder {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let query = self.query.clone().with(self.with.clone());
-        write!(f, "{}", query.to_string(PostgresQueryBuilder))
+        let results = self
+            .collection
+            .vector_search(
+                self.query,
+                self.pipeline
+                    .as_mut()
+                    .context("cannot fetch all without first calling vector_recall")?,
+            )
+            .await?;
+        results
+            .into_iter()
+            .map(|mut v| {
+                Ok((
+                    v["score"].as_f64().context("Error converting score")?,
+                    v["chunk"]
+                        .as_str()
+                        .context("Error converting chunk")?
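+                        // `as_str` borrows from `v`; take an owned String for the tuple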
+                        .to_string(),
+                    std::mem::take(&mut v["document"]).into(),
+                ))
+            })
+            .collect()
     }
 }
diff --git a/pgml-sdks/pgml/src/remote_embeddings.rs b/pgml-sdks/pgml/src/remote_embeddings.rs
index bcb84146c..f010c6c50 100644
--- a/pgml-sdks/pgml/src/remote_embeddings.rs
+++ b/pgml-sdks/pgml/src/remote_embeddings.rs
@@ -1,5 +1,5 @@
 use reqwest::{Client, RequestBuilder};
-use sqlx::postgres::PgPool;
+use sqlx::PgConnection;
 use std::env;
 use tracing::instrument;

@@ -8,7 +8,7 @@ use crate::{model::ModelRuntime, models, query_builder, types::Json};
 pub fn build_remote_embeddings<'a>(
     source: ModelRuntime,
     model_name: &'a str,
-    _model_parameters: &'a Json,
+    _model_parameters: Option<&'a Json>,
 ) -> anyhow::Result<Box<dyn RemoteEmbeddings<'a> + Sync + Send + 'a>> {
     match source {
         // OpenAI endpoint for embeddings does not take any model parameters
@@ -41,39 +41,40 @@ pub trait RemoteEmbeddings<'a> {
         self.parse_response(response)
     }

-    #[instrument(skip(self, pool))]
+    #[instrument(skip(self))]
     async fn get_chunks(
         &self,
         embeddings_table_name: &str,
         chunks_table_name: &str,
-        splitter_id: i64,
-        chunk_ids: &Option<Vec<i64>>,
-        pool: &PgPool,
+        chunk_ids: Option<&Vec<i64>>,
+        connection: &mut PgConnection,
         limit: Option<i64>,
     ) -> anyhow::Result<Vec<models::Chunk>> {
-        let limit = limit.unwrap_or(1000);
-
-        match chunk_ids {
-            Some(cids) => sqlx::query_as(&query_builder!(
-                "SELECT * FROM %s WHERE splitter_id = $1 AND id NOT IN (SELECT chunk_id FROM %s) AND id = ANY ($2) LIMIT $3",
-                chunks_table_name,
-                embeddings_table_name
-            ))
-            .bind(splitter_id)
-            .bind(cids)
-            .bind(limit)
-            .fetch_all(pool)
-            .await,
-            None => sqlx::query_as(&query_builder!(
-                "SELECT * FROM %s WHERE splitter_id = $1 AND id NOT IN (SELECT chunk_id FROM %s) LIMIT $2",
-                chunks_table_name,
-                embeddings_table_name
-            ))
-            .bind(splitter_id)
-            .bind(limit)
-            .fetch_all(pool)
+        // Requires _query_text to be declared out here so it lives long enough
+        let mut _query_text = "".to_string();
+        let query = match chunk_ids {
+            Some(chunk_ids) => {
+                _query_text =
+                    query_builder!("SELECT * FROM %s WHERE id = ANY ($1)", chunks_table_name);
+                sqlx::query_as(_query_text.as_str())
+                    .bind(chunk_ids)
+                    .bind(limit)
+            }
+            None => {
+                let limit = limit.unwrap_or(1000);
+                _query_text = query_builder!(
+                    "SELECT * FROM %s WHERE id NOT IN (SELECT chunk_id FROM %s) LIMIT $1",
+                    chunks_table_name,
+                    embeddings_table_name
+                );
+                sqlx::query_as(_query_text.as_str()).bind(limit)
+            }
+        };
+
+        query
+            .fetch_all(connection)
             .await
-        }.map_err(|e| anyhow::anyhow!(e))
+            .map_err(|e| anyhow::anyhow!(e))
     }

     #[instrument(skip(self, response))]
@@ -99,41 +100,39 @@ pub trait RemoteEmbeddings<'a> {
         Ok(embeddings)
     }

-    #[instrument(skip(self, pool))]
+    #[instrument(skip(self))]
     async fn generate_embeddings(
         &self,
         embeddings_table_name: &str,
         chunks_table_name: &str,
-        splitter_id: i64,
-        chunk_ids: Option<Vec<i64>>,
-        pool: &PgPool,
+        mut chunk_ids: Option<&Vec<i64>>,
+        connection: &mut PgConnection,
     ) -> anyhow::Result<()> {
         loop {
             let chunks = self
                 .get_chunks(
                     embeddings_table_name,
                     chunks_table_name,
-                    splitter_id,
-                    &chunk_ids,
-                    pool,
+                    chunk_ids,
+                    connection,
                     None,
                 )
                 .await?;
             if chunks.is_empty() {
                 break;
             }
-            let (chunk_ids, chunk_texts): (Vec<i64>, Vec<String>) = chunks
+            let (retrieved_chunk_ids, chunk_texts): (Vec<i64>, Vec<String>) = chunks
                 .into_iter()
                 .map(|chunk| (chunk.id, chunk.chunk))
                 .unzip();
             let embeddings = self.embed(chunk_texts).await?;

             let query_string_values = (0..embeddings.len())
-                .map(|i| format!("(${}, ${})", i * 2 + 1, i * 2 + 2))
+                .map(|i| query_builder!("($%d, $%d)", i * 2 + 1, i * 2 + 2))
                 .collect::<Vec<String>>()
                 .join(",");

             let query_string = format!(
-                "INSERT INTO %s (chunk_id, embedding) VALUES {}",
+                "INSERT INTO %s (chunk_id, embedding) VALUES {} ON CONFLICT (chunk_id) DO UPDATE SET embedding = EXCLUDED.embedding",
                 query_string_values
             );

@@ -141,10 +140,13 @@ pub trait RemoteEmbeddings<'a> {

             let mut query = sqlx::query(&query);
             for i in 0..embeddings.len() {
-                query = query.bind(chunk_ids[i]).bind(&embeddings[i]);
+                query = query.bind(retrieved_chunk_ids[i]).bind(&embeddings[i]);
             }

-            query.execute(pool).await?;
+            query.execute(&mut *connection).await?;
+
+            // Set chunk_ids to None so we don't just retrieve the same chunks over and over
+            chunk_ids = None;
         }
         Ok(())
     }
@@ -183,8 +185,11 @@ mod tests {
     #[tokio::test]
     async fn openai_remote_embeddings() -> anyhow::Result<()> {
         let params = serde_json::json!({}).into();
-        let openai_remote_embeddings =
-            build_remote_embeddings(ModelRuntime::OpenAI, "text-embedding-ada-002", &params)?;
+        let openai_remote_embeddings = build_remote_embeddings(
+            ModelRuntime::OpenAI,
+            "text-embedding-ada-002",
+            Some(&params),
+        )?;
         let embedding_size = openai_remote_embeddings.get_embedding_size().await?;
         assert!(embedding_size > 0);
         Ok(())
diff --git a/pgml-sdks/pgml/src/search_query_builder.rs b/pgml-sdks/pgml/src/search_query_builder.rs
new file mode 100644
index 000000000..3fb6a0db4
--- /dev/null
+++ b/pgml-sdks/pgml/src/search_query_builder.rs
@@ -0,0 +1,530 @@
+use anyhow::Context;
+use serde::Deserialize;
+use std::collections::HashMap;
+
+use sea_query::{
+    Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query,
+    SimpleExpr, WithClause,
+};
+use sea_query_binder::{SqlxBinder, SqlxValues};
+
+use crate::{
+    collection::Collection,
+    debug_sea_query,
+    filter_builder::FilterBuilder,
+    model::ModelRuntime,
+    models,
+    pipeline::Pipeline,
+    remote_embeddings::build_remote_embeddings,
+    types::{IntoTableNameAndSchema, Json, SIden},
+};
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidSemanticSearchAction {
+    query: String,
+    parameters: Option<Json>,
+    boost: Option<f32>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidFullTextSearchAction {
+    query: String,
+    boost: Option<f32>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidQueryActions {
+    full_text_search: Option<HashMap<String, ValidFullTextSearchAction>>,
+    semantic_search: Option<HashMap<String, ValidSemanticSearchAction>>,
+    filter: Option<Json>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidQuery {
+    query: ValidQueryActions,
+    // Need this when coming from JavaScript as everything is an f64 from JS
+    #[serde(default, deserialize_with = "crate::utils::deserialize_u64")]
+    limit: Option<u64>,
+}
+
+pub async fn build_search_query(
+    collection: &Collection,
+    query: Json,
+    pipeline: &Pipeline,
+) -> anyhow::Result<(String, SqlxValues)> {
+    let valid_query: ValidQuery = serde_json::from_value(query.0.clone())?;
+    let limit = valid_query.limit.unwrap_or(10);
+
+    let pipeline_table = format!("{}.pipelines", collection.name);
+    let documents_table = format!("{}.documents", collection.name);
+
+    let mut score_table_names = Vec::new();
+    let mut with_clause = WithClause::new();
+    let mut sum_expression: Option<SimpleExpr> = None;
+
+    let mut pipeline_cte = Query::select();
+    pipeline_cte
+        .from(pipeline_table.to_table_tuple())
+        .columns([models::PipelineIden::Schema])
+        .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name));
+    let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte);
+    pipeline_cte.table_name(Alias::new("pipeline"));
+    with_clause.cte(pipeline_cte);
+
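+    // Each semantic_search entry below becomes one recursive scoring CTE. A
+    // sketch of the accepted query shape (the field name "body" is assumed):
+    // {"query": {"semantic_search": {"body": {"query": "...", "boost": 2.0}},
+    //            "full_text_search": {"body": {"query": "...", "boost": 1.0}},
+    //            "filter": {...}},
+    //  "limit": 10}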
+    for (key, vsa) in valid_query.query.semantic_search.unwrap_or_default() {
+        let model_runtime = pipeline
+            .parsed_schema
+            .as_ref()
+            .map(|s| {
+                // Any of these errors means they have a malformed query
+                anyhow::Ok(
+                    s.get(&key)
+                        .as_ref()
+                        .context(format!("Bad query - {key} does not exist in schema"))?
+                        .semantic_search
+                        .as_ref()
+                        .context(format!(
+                            "Bad query - {key} does not have any directive to semantic_search"
+                        ))?
+                        .model
+                        .runtime,
+                )
+            })
+            .transpose()?
+            .unwrap_or(ModelRuntime::Python);
+
+        // Build the CTE we actually use later
+        let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key);
+        let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key);
+        let cte_name = format!("{key}_embedding_score");
+        let boost = vsa.boost.unwrap_or(1.);
+        let mut score_cte_non_recursive = Query::select();
+        let mut score_cte_recursive = Query::select();
+        match model_runtime {
+            ModelRuntime::Python => {
+                // Build the embedding CTE
+                let mut embedding_cte = Query::select();
+                embedding_cte.expr_as(
+                    Func::cust(SIden::Str("pgml.embed")).args([
+                        Expr::cust(format!(
+                            "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)",
+                        )),
+                        Expr::cust_with_values("text => $1", [&vsa.query]),
+                        Expr::cust_with_values("kwargs => $1", [vsa.parameters.unwrap_or_default().0]),
+                    ]),
+                    Alias::new("embedding"),
+                );
+                let mut embedding_cte = CommonTableExpression::from_select(embedding_cte);
+                embedding_cte.table_name(Alias::new(format!("{key}_embedding")));
+                with_clause.cte(embedding_cte);
+
+                score_cte_non_recursive
+                    .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
+                    .column((SIden::Str("documents"), SIden::Str("id")))
+                    .join_as(
+                        JoinType::InnerJoin,
+                        chunks_table.to_table_tuple(),
+                        Alias::new("chunks"),
+                        Expr::col((SIden::Str("chunks"), SIden::Str("id")))
+                            .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))),
+                    )
+                    .join_as(
+                        JoinType::InnerJoin,
+                        documents_table.to_table_tuple(),
+                        Alias::new("documents"),
+                        Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                            .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
+                    )
+                    .expr(Expr::cust(r#"ARRAY[documents.id] as previous_document_ids"#))
+                    .expr(Expr::cust(format!(
+                        r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"#
+                    )))
+                    .order_by_expr(Expr::cust(format!(
+                        r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"#
+                    )), Order::Asc )
+                    .limit(1);
+
+                score_cte_recursive
+                    .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
+                    .column((SIden::Str("documents"), SIden::Str("id")))
+                    .expr(Expr::cust(format!(r#""{cte_name}".previous_document_ids || documents.id"#)))
+                    .expr(Expr::cust(format!(
+                        r#"(1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector)) * {boost} AS score"#
+                    )))
+                    .and_where(Expr::cust(format!(r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"#)))
+                    .join(
+                        JoinType::Join,
+                        SIden::String(cte_name.clone()),
+                        Expr::cust("1 = 1"),
+                    )
+                    .join_as(
+                        JoinType::InnerJoin,
+                        chunks_table.to_table_tuple(),
+                        Alias::new("chunks"),
+                        Expr::col((SIden::Str("chunks"), SIden::Str("id")))
+                            .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))),
+                    )
+                    .join_as(
+                        JoinType::InnerJoin,
+                        documents_table.to_table_tuple(),
+                        Alias::new("documents"),
+                        Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                            .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
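+                        // joining back through chunks resolves each embedding
+                        // row to its source document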
+                    )
+                    .order_by_expr(Expr::cust(format!(
+                        r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"#
+                    )), Order::Asc )
+                    .limit(1);
+            }
+            ModelRuntime::OpenAI => {
+                // We can unwrap here as we know this is all set from above
+                let model = &pipeline
+                    .parsed_schema
+                    .as_ref()
+                    .unwrap()
+                    .get(&key)
+                    .unwrap()
+                    .semantic_search
+                    .as_ref()
+                    .unwrap()
+                    .model;
+
+                // Get the remote embedding
+                let embedding = {
+                    let remote_embeddings = build_remote_embeddings(
+                        model.runtime,
+                        &model.name,
+                        vsa.parameters.as_ref(),
+                    )?;
+                    let mut embeddings = remote_embeddings.embed(vec![vsa.query]).await?;
+                    std::mem::take(&mut embeddings[0])
+                };
+
+                score_cte_non_recursive
+                    .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
+                    .column((SIden::Str("documents"), SIden::Str("id")))
+                    .expr(Expr::cust("ARRAY[documents.id] as previous_document_ids"))
+                    .expr(Expr::cust_with_values(
+                        format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"),
+                        [embedding.clone()],
+                    ))
+                    .join_as(
+                        JoinType::InnerJoin,
+                        chunks_table.to_table_tuple(),
+                        Alias::new("chunks"),
+                        Expr::col((SIden::Str("chunks"), SIden::Str("id")))
+                            .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))),
+                    )
+                    .join_as(
+                        JoinType::InnerJoin,
+                        documents_table.to_table_tuple(),
+                        Alias::new("documents"),
+                        Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                            .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
+                    )
+                    .order_by_expr(
+                        Expr::cust_with_values(
+                            "embeddings.embedding <=> $1::vector",
+                            [embedding.clone()],
+                        ),
+                        Order::Asc,
+                    )
+                    .limit(1);
+
+                score_cte_recursive
+                    .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings"))
+                    .join(
+                        JoinType::Join,
+                        SIden::String(cte_name.clone()),
+                        Expr::cust("1 = 1"),
+                    )
+                    .column((SIden::Str("documents"), SIden::Str("id")))
+                    .expr(Expr::cust(format!(
+                        r#""{cte_name}".previous_document_ids || documents.id"#
+                    )))
+                    .expr(Expr::cust_with_values(
+                        format!("(1 - (embeddings.embedding <=> $1::vector)) * {boost} AS score"),
+                        [embedding.clone()],
+                    ))
+                    .and_where(Expr::cust(format!(
+                        r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"#
+                    )))
+                    .join_as(
+                        JoinType::InnerJoin,
+                        chunks_table.to_table_tuple(),
+                        Alias::new("chunks"),
+                        Expr::col((SIden::Str("chunks"), SIden::Str("id")))
+                            .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))),
+                    )
+                    .join_as(
+                        JoinType::InnerJoin,
+                        documents_table.to_table_tuple(),
+                        Alias::new("documents"),
+                        Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                            .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
+                    )
+                    .order_by_expr(
+                        Expr::cust_with_values(
+                            "embeddings.embedding <=> $1::vector",
+                            [embedding.clone()],
+                        ),
+                        Order::Asc,
+                    )
+                    .limit(1);
+            }
+        }
+
+        if let Some(filter) = &valid_query.query.filter {
+            let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?;
+            score_cte_non_recursive.cond_where(filter.clone());
+            score_cte_recursive.cond_where(filter);
+        }
+
+        let score_cte = Query::select()
+            .expr(Expr::cust("*"))
+            .from_subquery(score_cte_non_recursive, Alias::new("non_recursive"))
+            .union(sea_query::UnionType::All, score_cte_recursive)
+            .to_owned();
+
+        let mut score_cte = CommonTableExpression::from_select(score_cte);
+        score_cte.table_name(Alias::new(&cte_name));
+        with_clause.cte(score_cte);
+
+        // Add to the sum expression
+        sum_expression = if let Some(expr) = sum_expression {
+            Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))))
+        } else {
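+            // the first scored field seeds the running sum; later fields are
+            // added via `expr.add(..)` above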
+            Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))
+        };
+        score_table_names.push(cte_name);
+    }
+
+    for (key, vma) in valid_query.query.full_text_search.unwrap_or_default() {
+        let full_text_table = format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key);
+        let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key);
+        let boost = vma.boost.unwrap_or(1.0);
+
+        // Build the score CTE
+        let cte_name = format!("{key}_tsvectors_score");
+
+        let mut score_cte_non_recursive = Query::select()
+            .column((SIden::Str("documents"), SIden::Str("id")))
+            .expr_as(
+                Expr::cust_with_values(
+                    format!(
+                        r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32) * {boost}"#,
+                    ),
+                    [&vma.query],
+                ),
+                Alias::new("score")
+            )
+            .expr(Expr::cust(
+                "ARRAY[documents.id] as previous_document_ids",
+            ))
+            .from_as(
+                full_text_table.to_table_tuple(),
+                Alias::new("tsvectors"),
+            )
+            .and_where(Expr::cust_with_values(
+                format!(
+                    r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#,
+                ),
+                [&vma.query],
+            ))
+            .join_as(
+                JoinType::InnerJoin,
+                chunks_table.to_table_tuple(),
+                Alias::new("chunks"),
+                Expr::col((SIden::Str("chunks"), SIden::Str("id")))
+                    .equals((SIden::Str("tsvectors"), SIden::Str("chunk_id"))),
+            )
+            .join_as(
+                JoinType::InnerJoin,
+                documents_table.to_table_tuple(),
+                Alias::new("documents"),
+                Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                    .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
+            )
+            .order_by(SIden::Str("score"), Order::Desc)
+            .limit(1)
+            .to_owned();
+
+        let mut score_cte_recursive = Query::select()
+            .column((SIden::Str("documents"), SIden::Str("id")))
+            .expr_as(
+                Expr::cust_with_values(
+                    format!(
+                        r#"ts_rank(tsvectors.ts, plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1), 32) * {boost}"#,
+                    ),
+                    [&vma.query],
+                ),
+                Alias::new("score")
+            )
+            .expr(Expr::cust(format!(
+                r#""{cte_name}".previous_document_ids || documents.id"#
+            )))
+            .from_as(
+                full_text_table.to_table_tuple(),
+                Alias::new("tsvectors"),
+            )
+            .join(
+                JoinType::Join,
+                SIden::String(cte_name.clone()),
+                Expr::cust("1 = 1"),
+            )
+            .and_where(Expr::cust(format!(
+                r#"NOT documents.id = ANY("{cte_name}".previous_document_ids)"#
+            )))
+            .and_where(Expr::cust_with_values(
+                format!(
+                    r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#,
+                ),
+                [&vma.query],
+            ))
+            .join_as(
+                JoinType::InnerJoin,
+                chunks_table.to_table_tuple(),
+                Alias::new("chunks"),
+                Expr::col((SIden::Str("chunks"), SIden::Str("id")))
+                    .equals((SIden::Str("tsvectors"), SIden::Str("chunk_id"))),
+            )
+            .join_as(
+                JoinType::InnerJoin,
+                documents_table.to_table_tuple(),
+                Alias::new("documents"),
+                Expr::col((SIden::Str("documents"), SIden::Str("id")))
+                    .equals((SIden::Str("chunks"), SIden::Str("document_id"))),
+            )
+            .order_by(SIden::Str("score"), Order::Desc)
+            .limit(1)
+            .to_owned();
+
+        if let Some(filter) = &valid_query.query.filter {
+            let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?;
+            score_cte_recursive.cond_where(filter.clone());
+            score_cte_non_recursive.cond_where(filter);
+        }
+
+        let score_cte = Query::select()
+            .expr(Expr::cust("*"))
+            .from_subquery(score_cte_non_recursive, Alias::new("non_recursive"))
+            .union(sea_query::UnionType::All, score_cte_recursive)
+            .to_owned();
+
+        let mut score_cte = CommonTableExpression::from_select(score_cte);
+        score_cte.table_name(Alias::new(&cte_name));
+        with_clause.cte(score_cte);
+
+        // Add to the sum expression
+        sum_expression = if let Some(expr) = sum_expression {
+            Some(expr.add(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#))))
+        } else {
+            Some(Expr::cust(format!(r#"COALESCE("{cte_name}".score, 0.0)"#)))
+        };
+        score_table_names.push(cte_name);
+    }
+
+    let query = if let Some(select_from) = score_table_names.first() {
+        let score_table_names_e: Vec<SimpleExpr> = score_table_names
+            .clone()
+            .into_iter()
+            .map(|t| Expr::col((SIden::String(t), SIden::Str("id"))).into())
+            .collect();
+        let mut main_query = Query::select();
+        for i in 1..score_table_names_e.len() {
+            main_query.full_outer_join(
+                SIden::String(score_table_names[i].to_string()),
+                Expr::col((
+                    SIden::String(score_table_names[i].to_string()),
+                    SIden::Str("id"),
+                ))
+                .eq(Func::coalesce(score_table_names_e[0..i].to_vec())),
+            );
+        }
+        let id_select_expression = Func::coalesce(score_table_names_e);
+
+        let sum_expression = sum_expression
+            .context("query requires some scoring through full_text_search or semantic_search")?;
+        main_query
+            .expr_as(Expr::expr(id_select_expression.clone()), Alias::new("id"))
+            .expr_as(sum_expression, Alias::new("score"))
+            .column(SIden::Str("document"))
+            .from(SIden::String(select_from.to_string()))
+            .join_as(
+                JoinType::InnerJoin,
+                documents_table.to_table_tuple(),
+                Alias::new("documents"),
+                Expr::col((SIden::Str("documents"), SIden::Str("id"))).eq(id_select_expression),
+            )
+            .order_by(SIden::Str("score"), Order::Desc)
+            .limit(limit);
+
+        let mut main_query = CommonTableExpression::from_select(main_query);
+        main_query.table_name(Alias::new("main"));
+        with_clause.cte(main_query);
+
+        // Insert into searches table
+        let searches_table = format!("{}_{}.searches", collection.name, pipeline.name);
+        let searches_insert_query = Query::insert()
+            .into_table(searches_table.to_table_tuple())
+            .columns([SIden::Str("query")])
+            .values([query.0.into()])?
+            .returning_col(SIden::Str("id"))
+            .to_owned();
+        let mut searches_insert_query = CommonTableExpression::new()
+            .query(searches_insert_query)
+            .to_owned();
+        searches_insert_query.table_name(Alias::new("searches_insert"));
+        with_clause.cte(searches_insert_query);
+
+        // Insert into search_results table
+        let search_results_table = format!("{}_{}.search_results", collection.name, pipeline.name);
+        let jsonb_builder = score_table_names.iter().fold(String::new(), |acc, t| {
+            format!("{acc}, '{t}', (SELECT score FROM {t} WHERE {t}.id = main.id)")
+        });
+        let jsonb_builder = format!("JSONB_BUILD_OBJECT('total', score{jsonb_builder})");
+        let search_results_insert_query = Query::insert()
+            .into_table(search_results_table.to_table_tuple())
+            .columns([
+                SIden::Str("search_id"),
+                SIden::Str("document_id"),
+                SIden::Str("scores"),
+                SIden::Str("rank"),
+            ])
+            .select_from(
+                Query::select()
+                    .expr(Expr::cust("(SELECT id FROM searches_insert)"))
+                    .column(SIden::Str("id"))
+                    .expr(Expr::cust(jsonb_builder))
+                    .expr(Expr::cust("row_number() over()"))
+                    .from(SIden::Str("main"))
+                    .to_owned(),
+            )?
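+            // rank comes from `row_number() over()`, relying on the score DESC
+            // ordering of the `main` CTE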
+            .to_owned();
+        let mut search_results_insert_query = CommonTableExpression::new()
+            .query(search_results_insert_query)
+            .to_owned();
+        search_results_insert_query.table_name(Alias::new("search_results_insert"));
+        with_clause.cte(search_results_insert_query);
+
+        Query::select()
+            .expr(Expr::cust(
+                "JSONB_BUILD_OBJECT('search_id', (SELECT id FROM searches_insert), 'results', JSON_AGG(main.*))",
+            ))
+            .from(SIden::Str("main"))
+            .to_owned()
+    } else {
+        // TODO: Maybe let users filter documents only here?
+        anyhow::bail!("If you are only looking to filter documents checkout the `get_documents` method on the Collection")
+    };
+
+    // For whatever reason, sea query does not like multiple ctes if the cte is recursive
+    let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder);
+    let sql = sql.replace("WITH ", "WITH RECURSIVE ");
+    debug_sea_query!(DOCUMENT_SEARCH, sql, values);
+    Ok((sql, values))
+}
diff --git a/pgml-sdks/pgml/src/single_field_pipeline.rs b/pgml-sdks/pgml/src/single_field_pipeline.rs
new file mode 100644
index 000000000..4acba800f
--- /dev/null
+++ b/pgml-sdks/pgml/src/single_field_pipeline.rs
@@ -0,0 +1,153 @@
+use crate::model::Model;
+use crate::splitter::Splitter;
+use crate::types::Json;
+use crate::Pipeline;
+
+#[cfg(feature = "python")]
+use crate::{model::ModelPython, splitter::SplitterPython, types::JsonPython};
+
+#[allow(dead_code)]
+fn build_pipeline(
+    name: &str,
+    model: Option<Model>,
+    splitter: Option<Splitter>,
+    parameters: Option<Json>,
+) -> Pipeline {
+    let parameters = parameters.unwrap_or_default();
+    let schema = if let Some(model) = model {
+        let mut schema = serde_json::json!({
+            "text": {
+                "semantic_search": {
+                    "model": model.name,
+                    "parameters": model.parameters,
+                    "hnsw": parameters["hnsw"]
+                }
+            }
+        });
+        if let Some(splitter) = splitter {
+            schema["text"]["splitter"] = serde_json::json!({
+                "model": splitter.name,
+                "parameters": splitter.parameters
+            });
+        }
+        if parameters["full_text_search"]["active"]
+            .as_bool()
+            .unwrap_or_default()
+        {
+            schema["text"]["full_text_search"] = serde_json::json!({
+                "configuration": parameters["full_text_search"]["configuration"].as_str().map(|v| v.to_string()).unwrap_or_else(|| "english".to_string())
+            });
+        }
+        Some(schema.into())
+    } else {
+        None
+    };
+    Pipeline::new(name, schema).expect("Error converting pipeline into new multifield pipeline")
+}
+
+#[cfg(feature = "python")]
+#[pyo3::prelude::pyfunction]
+#[allow(non_snake_case)] // This doesn't seem to be working
+pub fn SingleFieldPipeline(
+    name: &str,
+    model: Option<ModelPython>,
+    splitter: Option<SplitterPython>,
+    parameters: Option<JsonPython>,
+) -> Pipeline {
+    let model = model.map(|m| *m.wrapped);
+    let splitter = splitter.map(|s| *s.wrapped);
+    let parameters = parameters.map(|p| p.wrapped);
+    build_pipeline(name, model, splitter, parameters)
+}
+
+#[cfg(feature = "javascript")]
+#[allow(non_snake_case)]
+pub fn SingleFieldPipeline<'a>(
+    mut cx: neon::context::FunctionContext<'a>,
+) -> neon::result::JsResult<'a, neon::types::JsValue> {
+    use rust_bridge::javascript::{FromJsType, IntoJsResult};
+    let name = cx.argument(0)?;
+    let name = String::from_js_type(&mut cx, name)?;
+
+    let model = cx.argument_opt(1);
+    let model = <Option<Model>>::from_option_js_type(&mut cx, model)?;
+
+    let splitter = cx.argument_opt(2);
+    let splitter = <Option<Splitter>>::from_option_js_type(&mut cx, splitter)?;
+
+    let parameters = cx.argument_opt(3);
+    let parameters = <Option<Json>>::from_option_js_type(&mut cx, parameters)?;
+
+    let pipeline = build_pipeline(&name, model, splitter, parameters);
+    let x = crate::pipeline::PipelineJavascript::from(pipeline);
+    x.into_js_result(&mut cx)
+}
+
+mod tests {
+    #[test]
+    fn pipeline_to_pipeline() -> anyhow::Result<()> {
+        use super::*;
+        use serde_json::json;
+
+        let model = Model::new(
+            Some("test_model".to_string()),
+            Some("pgml".to_string()),
+            Some(
+                json!({
+                    "test_parameter": 10
+                })
+                .into(),
+            ),
+        );
+        let splitter = Splitter::new(
+            Some("test_splitter".to_string()),
+            Some(
+                json!({
+                    "test_parameter": 11
+                })
+                .into(),
+            ),
+        );
+        let parameters = json!({
+            "full_text_search": {
+                "active": true,
+                "configuration": "test_configuration"
+            },
+            "hnsw": {
+                "m": 16,
+                "ef_construction": 64
+            }
+        });
+        let pipeline = build_pipeline(
+            "test_name",
+            Some(model),
+            Some(splitter),
+            Some(parameters.into()),
+        );
+        let schema = json!({
+            "text": {
+                "splitter": {
+                    "model": "test_splitter",
+                    "parameters": {
+                        "test_parameter": 11
+                    }
+                },
+                "semantic_search": {
+                    "model": "test_model",
+                    "parameters": {
+                        "test_parameter": 10
+                    },
+                    "hnsw": {
+                        "m": 16,
+                        "ef_construction": 64
+                    }
+                },
+                "full_text_search": {
+                    "configuration": "test_configuration"
+                }
+            }
+        });
+        assert_eq!(schema, pipeline.schema.unwrap().0);
+        Ok(())
+    }
+}
diff --git a/pgml-sdks/pgml/src/splitter.rs b/pgml-sdks/pgml/src/splitter.rs
index 85e85e3a8..96b1ed9da 100644
--- a/pgml-sdks/pgml/src/splitter.rs
+++ b/pgml-sdks/pgml/src/splitter.rs
@@ -1,17 +1,17 @@
-use anyhow::Context;
 use rust_bridge::{alias, alias_methods};
-use sqlx::postgres::{PgConnection, PgPool};
+use sqlx::{postgres::PgConnection, Pool, Postgres};
 use tracing::instrument;

 use crate::{
     collection::ProjectInfo,
-    get_or_initialize_pool, models, queries,
+    models, queries,
     types::{DateTime, Json},
 };

 #[cfg(feature = "python")]
 use crate::types::JsonPython;

+#[allow(dead_code)]
 #[derive(Debug, Clone)]
 pub(crate) struct SplitterDatabaseData {
     pub id: i64,
@@ -23,7 +23,6 @@ pub(crate) struct SplitterDatabaseData {
 pub struct Splitter {
     pub name: String,
     pub parameters: Json,
-    project_info: Option<ProjectInfo>,
     pub(crate) database_data: Option<SplitterDatabaseData>,
 }

@@ -54,28 +53,25 @@ impl Splitter {
         Self {
             name,
             parameters,
-            project_info: None,
             database_data: None,
         }
     }

     #[instrument(skip(self))]
-    pub(crate) async fn verify_in_database(&mut self, throw_if_exists: bool) -> anyhow::Result<()> {
+    pub(crate) async fn verify_in_database(
+        &mut self,
+        project_info: &ProjectInfo,
+        throw_if_exists: bool,
+        pool: &Pool<Postgres>,
+    ) -> anyhow::Result<()> {
         if self.database_data.is_none() {
-            let pool = self.get_pool().await?;
-
-            let project_info = self
-                .project_info
-                .as_ref()
-                .expect("Cannot verify splitter without project info");
-
             let splitter: Option<models::Splitter> = sqlx::query_as(
                 "SELECT * FROM pgml.splitters WHERE project_id = $1 AND name = $2 and parameters = $3",
             )
             .bind(project_info.id)
             .bind(&self.name)
             .bind(&self.parameters)
-            .fetch_optional(&pool)
+            .fetch_optional(pool)
             .await?;

             let splitter = if let Some(s) = splitter {
@@ -88,7 +84,7 @@ impl Splitter {
                 .bind(project_info.id)
                 .bind(&self.name)
                 .bind(&self.parameters)
-                .fetch_one(&pool)
+                .fetch_one(pool)
                 .await?
             };

@@ -106,51 +102,6 @@ impl Splitter {
             .await?;
         Ok(())
     }
-
-    pub(crate) fn set_project_info(&mut self, project_info: ProjectInfo) {
-        self.project_info = Some(project_info)
-    }
-
-    #[instrument(skip(self))]
-    pub(crate) async fn to_dict(&mut self) -> anyhow::Result<Json> {
-        self.verify_in_database(false).await?;
-
-        let database_data = self
-            .database_data
-            .as_ref()
-            .context("Splitter must be verified to call to_dict")?;
-
-        Ok(serde_json::json!({
-            "id": database_data.id,
-            "created_at": database_data.created_at,
-            "name": self.name,
-            "parameters": *self.parameters,
-        })
-        .into())
-    }
-
-    async fn get_pool(&self) -> anyhow::Result<PgPool> {
-        let database_url = &self
-            .project_info
-            .as_ref()
-            .context("Project info required to call method splitter.get_pool()")?
-            .database_url;
-        get_or_initialize_pool(database_url).await
-    }
-}
-
-impl From<models::PipelineWithModelAndSplitter> for Splitter {
-    fn from(x: models::PipelineWithModelAndSplitter) -> Self {
-        Self {
-            name: x.splitter_name,
-            parameters: x.splitter_parameters,
-            project_info: None,
-            database_data: Some(SplitterDatabaseData {
-                id: x.splitter_id,
-                created_at: x.splitter_created_at,
-            }),
-        }
-    }
 }

 impl From<models::Splitter> for Splitter {
@@ -158,7 +109,6 @@ impl From<models::Splitter> for Splitter {
         Self {
             name: splitter.name,
             parameters: splitter.parameters,
-            project_info: None,
             database_data: Some(SplitterDatabaseData {
                 id: splitter.id,
                 created_at: splitter.created_at,
diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs
index 00dd556f7..d20089463 100644
--- a/pgml-sdks/pgml/src/transformer_pipeline.rs
+++ b/pgml-sdks/pgml/src/transformer_pipeline.rs
@@ -74,7 +74,7 @@ impl Stream for TransformerStream {
                 let s: *mut Self = s;
                 let s = Box::leak(Box::from_raw(s));
                 s.future = Some(Box::pin(
-                    sqlx::query(&s.query).fetch_all(s.transaction.as_mut().unwrap()),
+                    sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()),
                 ));
             }
         }
@@ -94,7 +94,7 @@ impl Stream for TransformerStream {
                 let s: *mut Self = s;
                 let s = Box::leak(Box::from_raw(s));
                 s.future = Some(Box::pin(
-                    sqlx::query(&s.query).fetch_all(s.transaction.as_mut().unwrap()),
+                    sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()),
                 ));
             }
         }
diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs
index bdf7308a3..1a51e4f20 100644
--- a/pgml-sdks/pgml/src/types.rs
+++ b/pgml-sdks/pgml/src/types.rs
@@ -3,12 +3,12 @@ use futures::{Stream, StreamExt};
 use itertools::Itertools;
 use rust_bridge::alias_manual;
 use sea_query::Iden;
-use serde::Serialize;
+use serde::{Deserialize, Serialize};
 use std::ops::{Deref, DerefMut};

 /// A wrapper around serde_json::Value
 // #[derive(sqlx::Type, sqlx::FromRow, Debug)]
-#[derive(alias_manual, sqlx::Type, Debug, Clone)]
+#[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)]
 #[sqlx(transparent)]
 pub struct Json(pub serde_json::Value);

diff --git a/pgml-sdks/pgml/src/utils.rs b/pgml-sdks/pgml/src/utils.rs
index a8c040bc9..c1d447bb0 100644
--- a/pgml-sdks/pgml/src/utils.rs
+++ b/pgml-sdks/pgml/src/utils.rs
@@ -3,6 +3,11 @@ use indicatif::{ProgressBar, ProgressStyle};
 use lopdf::Document;
 use std::fs;
 use std::path::Path;
+use std::time::Duration;
+
+use serde::de::{self, Visitor};
+use serde::Deserializer;
+use std::fmt;

 /// A more type flexible version of format!
 #[macro_export]
@@ -25,18 +30,50 @@ macro_rules! query_builder {
     }};
 }

-pub fn default_progress_spinner(size: u64) -> ProgressBar {
-    ProgressBar::new(size).with_style(
-        ProgressStyle::with_template("[{elapsed_precise}] {spinner:0.cyan/blue} {prefix}: {msg}")
-            .unwrap(),
-    )
+/// Used to debug sqlx queries
+#[macro_export]
+macro_rules! debug_sqlx_query {
+    ($name:expr, $query:expr) => {{
+        let name = stringify!($name);
+        let sql = $query.to_string();
+        let sql = sea_query::Query::select().expr(sea_query::Expr::cust(sql)).to_string(sea_query::PostgresQueryBuilder);
+        let sql = sql.replacen("SELECT", "", 1);
+        let span = tracing::span!(tracing::Level::DEBUG, "debug_query");
+        tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql);
+    }};
+
+    ($name:expr, $query:expr, $( $x:expr ),*) => {{
+        let name = stringify!($name);
+        let sql = $query.to_string();
+        let sql = sea_query::Query::select().expr(sea_query::Expr::cust_with_values(sql, [$(
+            sea_query::Value::from($x.clone()),
+        )*])).to_string(sea_query::PostgresQueryBuilder);
+        let sql = sql.replacen("SELECT", "", 1);
+        let span = tracing::span!(tracing::Level::DEBUG, "debug_query");
+        tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql);
+    }};
+}
+
+/// Used to debug sea_query queries
+#[macro_export]
+macro_rules! debug_sea_query {
+    ($name:expr, $query:expr, $values:expr) => {{
+        let name = stringify!($name);
+        let sql = $query.to_string();
+        let sql = sea_query::Query::select().expr(sea_query::Expr::cust_with_values(sql, $values.clone().0)).to_string(sea_query::PostgresQueryBuilder);
+        let sql = sql.replacen("SELECT", "", 1);
+        let span = tracing::span!(tracing::Level::DEBUG, "debug_query");
+        tracing::event!(parent: &span, tracing::Level::DEBUG, %name, %sql);
+    }};
 }

 pub fn default_progress_bar(size: u64) -> ProgressBar {
-    ProgressBar::new(size).with_style(
+    let bar = ProgressBar::new(size).with_style(
         ProgressStyle::with_template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} ")
             .unwrap(),
-    )
+    );
+    bar.enable_steady_tick(Duration::from_millis(100));
+    bar
 }

 pub fn get_file_contents(path: &Path) -> anyhow::Result<String> {
@@ -63,3 +100,40 @@ pub fn get_file_contents(path: &Path) -> anyhow::Result<String> {
             .with_context(|| format!("Error reading file: {}", path.display()))?,
     })
 }
+
+struct U64Visitor;
+impl<'de> Visitor<'de> for U64Visitor {
+    type Value = u64;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        formatter.write_str("some number")
+    }
+
+    fn visit_i32<E>(self, value: i32) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Ok(value as u64)
+    }
+
+    fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Ok(value)
+    }
+
+    fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Ok(value as u64)
+    }
+}
+
+pub fn deserialize_u64<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    deserializer.deserialize_u64(U64Visitor).map(Some)
+}
diff --git a/pgml-sdks/pgml/src/vector_search_query_builder.rs b/pgml-sdks/pgml/src/vector_search_query_builder.rs
new file mode 100644
index 000000000..df4f54e79
--- /dev/null
+++ b/pgml-sdks/pgml/src/vector_search_query_builder.rs
@@ -0,0 +1,240 @@
+use anyhow::Context;
+use serde::Deserialize;
+use std::collections::HashMap;
+
+use sea_query::{
+    Alias, CommonTableExpression, Expr, Func, JoinType, Order, PostgresQueryBuilder, Query,
+    WithClause,
+};
+use sea_query_binder::{SqlxBinder, SqlxValues};
+
+use crate::{
+    collection::Collection,
+    debug_sea_query,
+    filter_builder::FilterBuilder,
+    model::ModelRuntime,
+    models,
+    pipeline::Pipeline,
+    remote_embeddings::build_remote_embeddings,
+    types::{IntoTableNameAndSchema, Json, SIden},
+};
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidField {
+    query: String,
+    parameters: Option<Json>,
+    full_text_filter: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidQueryActions {
+    fields: Option<HashMap<String, ValidField>>,
+    filter: Option<Json>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+struct ValidQuery {
+    query: ValidQueryActions,
+    // Need this when coming from JavaScript as everything is an f64 from JS
+    #[serde(default, deserialize_with = "crate::utils::deserialize_u64")]
+    limit: Option<u64>,
+}
+
+pub async fn build_vector_search_query(
+    query: Json,
+    collection: &Collection,
+    pipeline: &Pipeline,
+) -> anyhow::Result<(String, SqlxValues)> {
+    let valid_query: ValidQuery = serde_json::from_value(query.0)?;
+    let limit = valid_query.limit.unwrap_or(10);
+    let fields = valid_query.query.fields.unwrap_or_default();
+
+    if fields.is_empty() {
+        anyhow::bail!("at least one field is required to search over")
+    }
+
+    let pipeline_table = format!("{}.pipelines", collection.name);
+    let documents_table = format!("{}.documents", collection.name);
+
+    let mut queries = Vec::new();
+    let mut with_clause = WithClause::new();
+
+    let mut pipeline_cte = Query::select();
+    pipeline_cte
+        .from(pipeline_table.to_table_tuple())
+        .columns([models::PipelineIden::Schema])
+        .and_where(Expr::col(models::PipelineIden::Name).eq(&pipeline.name));
+    let mut pipeline_cte = CommonTableExpression::from_select(pipeline_cte);
+    pipeline_cte.table_name(Alias::new("pipeline"));
+    with_clause.cte(pipeline_cte);
+
+    for (key, vf) in fields {
+        let model_runtime = pipeline
+            .parsed_schema
+            .as_ref()
+            .map(|s| {
+                // Any of these errors means they have a malformed query
+                anyhow::Ok(
+                    s.get(&key)
+                        .as_ref()
+                        .context(format!("Bad query - {key} does not exist in schema"))?
+                        .semantic_search
+                        .as_ref()
+                        .context(format!(
+                            "Bad query - {key} does not have any directive to semantic_search"
+                        ))?
+                        .model
+                        .runtime,
+                )
+            })
+            .transpose()?
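+            // no parsed schema yet: fall back to the local Python runtime
+            // rather than assuming a remote one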
+ .unwrap_or(ModelRuntime::Python); + + let chunks_table = format!("{}_{}.{}_chunks", collection.name, pipeline.name, key); + let embeddings_table = format!("{}_{}.{}_embeddings", collection.name, pipeline.name, key); + + let mut query = Query::select(); + + match model_runtime { + ModelRuntime::Python => { + // Build the embedding CTE + let mut embedding_cte = Query::select(); + embedding_cte.expr_as( + Func::cust(SIden::Str("pgml.embed")).args([ + Expr::cust(format!( + "transformer => (SELECT schema #>> '{{{key},semantic_search,model}}' FROM pipeline)", + )), + Expr::cust_with_values("text => $1", [vf.query]), + Expr::cust_with_values("kwargs => $1", [vf.parameters.unwrap_or_default().0]), + ]), + Alias::new("embedding"), + ); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new(format!("{key}_embedding"))); + with_clause.cte(embedding_cte); + + query + .expr(Expr::cust(format!( + r#"1 - (embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector) AS score"# + ))) + .order_by_expr(Expr::cust(format!( + r#"embeddings.embedding <=> (SELECT embedding FROM "{key}_embedding")::vector"# + )), Order::Asc); + } + ModelRuntime::OpenAI => { + // We can unwrap here as we know this is all set from above + let model = &pipeline + .parsed_schema + .as_ref() + .unwrap() + .get(&key) + .unwrap() + .semantic_search + .as_ref() + .unwrap() + .model; + + // Get the remote embedding + let embedding = { + let remote_embeddings = build_remote_embeddings( + model.runtime, + &model.name, + vf.parameters.as_ref(), + )?; + let mut embeddings = + remote_embeddings.embed(vec![vf.query.to_string()]).await?; + std::mem::take(&mut embeddings[0]) + }; + + // Build the score CTE + query + .expr(Expr::cust_with_values( + r#"1 - (embeddings.embedding <=> $1::vector) AS score"#, + [embedding.clone()], + )) + .order_by_expr( + Expr::cust_with_values( + r#"embeddings.embedding <=> $1::vector"#, + [embedding], + ), + Order::Asc, + ); + } + } + + query + .column((SIden::Str("documents"), SIden::Str("id"))) + .column((SIden::Str("chunks"), SIden::Str("chunk"))) + .column((SIden::Str("documents"), SIden::Str("document"))) + .from_as(embeddings_table.to_table_tuple(), Alias::new("embeddings")) + .join_as( + JoinType::InnerJoin, + chunks_table.to_table_tuple(), + Alias::new("chunks"), + Expr::col((SIden::Str("chunks"), SIden::Str("id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), + ) + .join_as( + JoinType::InnerJoin, + documents_table.to_table_tuple(), + Alias::new("documents"), + Expr::col((SIden::Str("documents"), SIden::Str("id"))) + .equals((SIden::Str("chunks"), SIden::Str("document_id"))), + ) + .limit(limit); + + if let Some(filter) = &valid_query.query.filter { + let filter = FilterBuilder::new(filter.clone().0, "documents", "document").build()?; + query.cond_where(filter); + } + + if let Some(full_text_search) = &vf.full_text_filter { + let full_text_table = + format!("{}_{}.{}_tsvectors", collection.name, pipeline.name, key); + query + .and_where(Expr::cust_with_values( + format!( + r#"tsvectors.ts @@ plainto_tsquery((SELECT oid FROM pg_ts_config WHERE cfgname = (SELECT schema #>> '{{{key},full_text_search,configuration}}' FROM pipeline)), $1)"#, + ), + [full_text_search], + )) + .join_as( + JoinType::InnerJoin, + full_text_table.to_table_tuple(), + Alias::new("tsvectors"), + Expr::col((SIden::Str("tsvectors"), SIden::Str("chunk_id"))) + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))) + ); + } + + let mut 
wrapper_query = Query::select();
+        wrapper_query
+            .columns([
+                SIden::Str("document"),
+                SIden::Str("chunk"),
+                SIden::Str("score"),
+            ])
+            .from_subquery(query, Alias::new("s"));
+
+        queries.push(wrapper_query);
+    }
+
+    // Union all of the queries together
+    let mut query = queries.pop().context("no query")?;
+    for q in queries.into_iter() {
+        query.union(sea_query::UnionType::All, q);
+    }
+
+    // Resort and limit
+    query
+        .order_by(SIden::Str("score"), Order::Desc)
+        .limit(limit);
+
+    let (sql, values) = query.with(with_clause).build_sqlx(PostgresQueryBuilder);
+
+    debug_sea_query!(VECTOR_SEARCH, sql, values);
+    Ok((sql, values))
+}
diff --git a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs
index cf4f04316..a453bf14f 100644
--- a/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs
+++ b/pgml-sdks/rust-bridge/rust-bridge-macros/src/python.rs
@@ -221,8 +221,9 @@ pub fn generate_python_methods(
                 let st = r.to_string();
                 Some(if st.contains('&') {
                     let st = st.replace("self", &wrapped_type_ident.to_string());
-                    let s = syn::parse_str::<syn::Type>(&st).unwrap_or_else(|_| panic!("Error converting self type to necessary syn type: {:?}",
-                        r));
+                    let s = syn::parse_str::<syn::Type>(&st).unwrap_or_else(|_| {
+                        panic!("Error converting self type to necessary syn type: {:?}", r)
+                    });
                     s.to_token_stream()
                 } else {
                     quote! { #wrapped_type_ident }
@@ -265,6 +266,7 @@ pub fn generate_python_methods(
     };

     // The new function for pyO3 requires some unique syntax
+    // The way we use the #convert_from assumes that new has a return type
     let (signature, middle) = if method_ident == "new" {
         let signature = quote! {
             #[new]
@@ -296,7 +298,7 @@ pub fn generate_python_methods(
                 use rust_bridge::python::CustomInto;
                 #prepared_wrapper_arguments
                 #middle
-                let x: Self = x.custom_into();
+                let x: #convert_from = x.custom_into();
                 Ok(x)
             };
             (signature, middle)