From 98cb5345673e882394dc63baddbcbbde374bb976 Mon Sep 17 00:00:00 2001
From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 1 Sep 2023 09:02:27 -0700
Subject: [PATCH 01/11] Started migration and hnsw move prep

---
 pgml-sdks/pgml/src/builtins.rs                |  6 +-
 pgml-sdks/pgml/src/collection.rs              |  3 +-
 pgml-sdks/pgml/src/lib.rs                     | 44 ++++++-----
 pgml-sdks/pgml/src/migrations/mod.rs          | 78 +++++++++++++++++++
 .../pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 42 ++++++++++
 pgml-sdks/pgml/src/pipeline.rs                |  4 +-
 pgml-sdks/pgml/src/queries.rs                 |  5 +-
 7 files changed, 154 insertions(+), 28 deletions(-)
 create mode 100644 pgml-sdks/pgml/src/migrations/mod.rs
 create mode 100644 pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs

diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs
index 60465c130..7dd887a34 100644
--- a/pgml-sdks/pgml/src/builtins.rs
+++ b/pgml-sdks/pgml/src/builtins.rs
@@ -92,11 +92,11 @@ impl Builtins {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::init_logger;
+    use crate::internal_init_logger;
 
     #[sqlx::test]
     async fn can_query() -> anyhow::Result<()> {
-        init_logger(None, None).ok();
+        internal_init_logger(None, None).ok();
         let builtins = Builtins::new(None);
         let query = "SELECT 10";
         let results = builtins.query(query).fetch_all().await?;
@@ -106,7 +106,7 @@ mod tests {
 
     #[sqlx::test]
     async fn can_transform() -> anyhow::Result<()> {
-        init_logger(None, None).ok();
+        internal_init_logger(None, None).ok();
         let builtins = Builtins::new(None);
         let task = Json::from(serde_json::json!("translation_en_to_fr"));
         let inputs = vec!["test1".to_string(), "test2".to_string()];
diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs
index 23fe6df42..2f76ab1b9 100644
--- a/pgml-sdks/pgml/src/collection.rs
+++ b/pgml-sdks/pgml/src/collection.rs
@@ -210,9 +210,10 @@ impl Collection {
             .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str())
             .await?;
 
-        let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id) VALUES ($1, $2) ON CONFLICT (name) DO NOTHING RETURNING *")
+        let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *")
             .bind(&self.name)
             .bind(project_id)
+            .bind(crate::SDK_VERSION)
             .fetch_one(&mut *transaction)
             .await?;
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index 8c6c355ec..e6a4868f3 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -12,6 +12,7 @@ use tokio::runtime::{Builder, Runtime};
 use tracing::Level;
 use tracing_subscriber::FmtSubscriber;
 
+mod migrations;
 mod builtins;
 mod collection;
 mod filter_builder;
@@ -34,6 +35,9 @@ pub use model::Model;
 pub use pipeline::Pipeline;
 pub use splitter::Splitter;
 
+// This is used when inserting collections to set the sdk_version used during creation
+static SDK_VERSION: &'static str = "0.9.2";
+
 // Store the database(s) in a global variable so that we can access them from anywhere
 // This is not necessarily idiomatic Rust, but it is a good way to accomplish what we need
 static DATABASE_POOLS: RwLock<Option<HashMap<String, PgPool>>> = RwLock::new(None);
@@ -74,7 +78,7 @@ impl From<&str> for LogFormat {
 }
 
 #[allow(dead_code)]
-fn init_logger(level: Option<String>, format: Option<String>) -> anyhow::Result<()> {
+fn internal_init_logger(level: Option<String>, format: Option<String>) -> anyhow::Result<()> {
     let level = level.unwrap_or_else(|| env::var("LOG_LEVEL").unwrap_or("".to_string()));
     let level = match level.as_str() {
         "TRACE" =>
Level::TRACE, @@ -124,15 +128,15 @@ fn get_or_set_runtime<'a>() -> &'a Runtime { #[cfg(feature = "python")] #[pyo3::prelude::pyfunction] -fn py_init_logger(level: Option, format: Option) -> pyo3::PyResult<()> { - init_logger(level, format).ok(); +fn init_logger(level: Option, format: Option) -> pyo3::PyResult<()> { + internal_init_logger(level, format).ok(); Ok(()) } #[cfg(feature = "python")] #[pyo3::pymodule] fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { - m.add_function(pyo3::wrap_pyfunction!(py_init_logger, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -142,7 +146,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { } #[cfg(feature = "javascript")] -fn js_init_logger( +fn init_logger( mut cx: neon::context::FunctionContext, ) -> neon::result::JsResult { use rust_bridge::javascript::{FromJsType, IntoJsResult}; @@ -150,14 +154,14 @@ fn js_init_logger( let level = >::from_option_js_type(&mut cx, level)?; let format = cx.argument_opt(1); let format = >::from_option_js_type(&mut cx, format)?; - init_logger(level, format).ok(); + internal_init_logger(level, format).ok(); ().into_js_result(&mut cx) } #[cfg(feature = "javascript")] #[neon::main] fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { - cx.export_function("js_init_logger", js_init_logger)?; + cx.export_function("init_logger", init_logger)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; cx.export_function("newSplitter", splitter::SplitterJavascript::new)?; @@ -195,7 +199,7 @@ mod tests { #[sqlx::test] async fn can_create_collection() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_ccc_0", None); assert!(collection.database_data.is_none()); collection.verify_in_database(false).await?; @@ -206,7 +210,7 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipeline() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -236,7 +240,7 @@ mod tests { #[sqlx::test] async fn can_add_remove_pipelines() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline1 = Pipeline::new( @@ -280,7 +284,7 @@ mod tests { #[sqlx::test] async fn sync_multiple_pipelines() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline1 = Pipeline::new( @@ -337,7 +341,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_local_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -372,7 +376,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new( Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), @@ -411,7 +415,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_query_builder() -> 
anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -448,7 +452,7 @@ mod tests { #[sqlx::test] async fn can_vector_search_with_query_builder_with_remote_embeddings() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new( Some("text-embedding-ada-002".to_string()), Some("openai".to_string()), @@ -489,7 +493,7 @@ mod tests { #[sqlx::test] async fn can_filter_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); let mut pipeline = Pipeline::new( @@ -558,7 +562,7 @@ mod tests { #[sqlx::test] async fn can_upsert_and_filter_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -651,7 +655,7 @@ mod tests { #[sqlx::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let mut collection = Collection::new("test_r_c_cpgd_2", None); collection .upsert_documents(generate_dummy_documents(10)) @@ -733,7 +737,7 @@ mod tests { #[sqlx::test] async fn can_filter_and_paginate_get_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::default(); let splitter = Splitter::default(); let mut pipeline = Pipeline::new( @@ -836,7 +840,7 @@ mod tests { #[sqlx::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { - init_logger(None, None).ok(); + internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); let mut pipeline = Pipeline::new( diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs new file mode 100644 index 000000000..e1418698f --- /dev/null +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -0,0 +1,78 @@ +use futures::FutureExt; +use itertools::Itertools; +use sqlx::PgPool; +use tracing::instrument; + +use crate::get_or_initialize_pool; + +#[path = "pgml--0.9.1--0.9.2.rs"] +mod pgml091_092; + +// There is probably a better way to write these types and bypass the need for the closure pass +// through, but it is proving to be difficult +// We could also probably remove some unnecessary clones in the call_migrate function if I was savy +// enough to reconcile the lifetimes +type MigrateFn = + &'static dyn Fn(PgPool, Vec) -> futures::future::BoxFuture<'static, anyhow::Result<()>>; +const VERSION_MIGRATIONS: &'static [(&'static str, MigrateFn)] = + &[("0.9.2", &|p, c| pgml091_092::migrate(p, c).boxed())]; + +#[instrument] +pub async fn migrate() -> anyhow::Result<()> { + let pool = get_or_initialize_pool(&None).await?; + let results: Result, _> = + sqlx::query_as("SELECT version, id FROM pgml.collections") + .fetch_all(&pool) + .await; + match results { + Ok(collections) => { + let collections = collections.into_iter().into_group_map(); + for (version, collection_ids) in collections.into_iter() { + call_migrate(pool.clone(), version, collection_ids).await? 
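// For reference: `into_group_map` from itertools folds an iterator of (K, V)
// pairs into a HashMap<K, Vec<V>>, so the (version, id) rows selected above
// become a map from version string to the ids of the collections created at
// that version. A minimal sketch with hypothetical ids:
use itertools::Itertools;

let rows: Vec<(String, i64)> = vec![
    ("0.9.1".into(), 1),
    ("0.9.1".into(), 2),
    ("0.9.2".into(), 3),
];
let grouped = rows.into_iter().into_group_map();
assert_eq!(grouped["0.9.1"], vec![1, 2]);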
+            }
+            Ok(())
+        }
+        Err(error) => {
+            let morphed_error = error
+                .as_database_error()
+                .map(|e| e.code().map(|c| c.to_string()));
+            if let Some(Some(db_error_code)) = morphed_error {
+                if db_error_code == "42703" {
+                    pgml091_092::migrate(pool, vec![]).await
+                } else {
+                    anyhow::bail!(error)
+                }
+            } else {
+                anyhow::bail!(error)
+            }
+        }
+    }
+}
+
+async fn call_migrate(
+    pool: PgPool,
+    version: String,
+    collection_ids: Vec<i64>,
+) -> anyhow::Result<()> {
+    let position = VERSION_MIGRATIONS.iter().position(|(v, _)| v == &version);
+    if let Some(p) = position {
+        // We run each migration in order that needs to be run for the collections
+        for (_, callback) in VERSION_MIGRATIONS.iter().skip(p + 1) {
+            callback(pool.clone(), collection_ids.clone()).await?
+        }
+    }
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::internal_init_logger;
+
+    #[tokio::test]
+    async fn test_migrate() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        migrate().await?;
+        Ok(())
+    }
+}
diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
new file mode 100644
index 000000000..adcc18b3c
--- /dev/null
+++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
@@ -0,0 +1,42 @@
+use crate::{queries, query_builder};
+use sqlx::Executor;
+use sqlx::PgPool;
+use tracing::instrument;
+
+#[instrument(skip(pool))]
+pub async fn migrate(pool: PgPool, _: Vec<i64>) -> anyhow::Result<()> {
+    let collection_names: Vec<String> = sqlx::query_scalar("SELECT name FROM pgml.collections")
+        .fetch_all(&pool)
+        .await?;
+    for collection_name in collection_names {
+        let table_name = format!("{}.pipelines", collection_name);
+        let pipeline_names: Vec<String> =
+            sqlx::query_scalar(&query_builder!("SELECT name FROM %s", table_name))
+                .fetch_all(&pool)
+                .await?;
+        for pipeline_name in pipeline_names {
+            let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name);
+            pool.execute(
+                query_builder!(
+                    queries::CREATE_INDEX_USING_HNSW,
+                    "",
+                    "hnsw_vector_index",
+                    &table_name,
+                    "embedding vector_cosine_ops"
+                )
+                .as_str(),
+            )
+            .await?;
+        }
+    }
+
+    // Required to set the default value for a not null column being added, but we want to remove
+    // it right after
+    let mut transaction = pool.begin().await?;
+    transaction.execute("ALTER TABLE pgml.collections ADD COLUMN IF NOT EXISTS sdk_version text NOT NULL DEFAULT '0.9.2'").await?;
+    transaction
+        .execute("ALTER TABLE pgml.collections ALTER COLUMN sdk_version DROP DEFAULT")
+        .await?;
+    transaction.commit().await?;
+    Ok(())
+}
diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs
index 87e632b34..e9c76e0ae 100644
--- a/pgml-sdks/pgml/src/pipeline.rs
+++ b/pgml-sdks/pgml/src/pipeline.rs
@@ -688,9 +688,9 @@ impl Pipeline {
         transaction
             .execute(
                 query_builder!(
-                    queries::CREATE_INDEX_USING_IVFFLAT,
+                    queries::CREATE_INDEX_USING_HNSW,
                     "",
-                    "vector_index",
+                    "hnsw_vector_index",
                     &embeddings_table_name,
                     "embedding vector_cosine_ops"
                 )
diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs
index 31122aac4..02bd985b9 100644
--- a/pgml-sdks/pgml/src/queries.rs
+++ b/pgml-sdks/pgml/src/queries.rs
@@ -8,6 +8,7 @@ CREATE TABLE IF NOT EXISTS pgml.collections (
   name text NOT NULL,
   active BOOLEAN DEFAULT TRUE,
   project_id int8 NOT NULL REFERENCES pgml.projects ON DELETE CASCADE ON UPDATE CASCADE DEFERRABLE INITIALLY DEFERRED,
+  sdk_version text,
   UNIQUE (name)
 );
 "#;
@@ -88,8 +89,8 @@ pub const CREATE_INDEX_USING_GIN: &str = r#"
 CREATE INDEX %d IF NOT EXISTS %s ON %s
USING GIN (%d); "#; -pub const CREATE_INDEX_USING_IVFFLAT: &str = r#" -CREATE INDEX %d IF NOT EXISTS %s ON %s USING ivfflat (%d); +pub const CREATE_INDEX_USING_HNSW: &str = r#" +CREATE INDEX %d IF NOT EXISTS %S on %s using hnsw (%d); "#; ///////////////////////////// From f3cbf9fcfdc6762ecc6c0fe2ee616e4835b4b0ea Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Sep 2023 16:34:13 -0700 Subject: [PATCH 02/11] Almost ready for 0.9.2 --- pgml-sdks/pgml/Cargo.lock | 13 ++ pgml-sdks/pgml/Cargo.toml | 3 +- pgml-sdks/pgml/build.rs | 6 +- .../javascript/tests/typescript-tests/test.ts | 10 +- pgml-sdks/pgml/python/pgml/pgml.pyi | 3 +- pgml-sdks/pgml/python/tests/test.py | 12 +- pgml-sdks/pgml/src/builtins.rs | 2 +- pgml-sdks/pgml/src/collection.rs | 2 +- pgml-sdks/pgml/src/lib.rs | 114 +++++++++++++++++- pgml-sdks/pgml/src/migrations/mod.rs | 81 +++++++------ .../pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 20 ++- pgml-sdks/pgml/src/pipeline.rs | 80 ++++++++---- pgml-sdks/pgml/src/queries.rs | 2 +- pgml-sdks/pgml/src/query_builder.rs | 38 ++++-- pgml-sdks/pgml/src/query_runner.rs | 9 +- pgml-sdks/pgml/src/types.rs | 3 + 16 files changed, 307 insertions(+), 91 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index dc5b7dada..2faa354f3 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -50,6 +50,17 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +[[package]] +name = "async-recursion" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e97ce7de6cf12de5d7226c73f5ba9811622f4db3a5b91b55c53e987e5f91cba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "async-trait" version = "0.1.71" @@ -1235,6 +1246,7 @@ name = "pgml" version = "0.9.1" dependencies = [ "anyhow", + "async-recursion", "async-trait", "chrono", "futures", @@ -1344,6 +1356,7 @@ version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index b3d15786a..7a0b23c5d 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -20,7 +20,7 @@ serde_json = "1.0.9" anyhow = "1.0.9" tokio = { version = "1.28.2", features = [ "macros" ] } chrono = "0.4.9" -pyo3 = { version = "0.18.3", optional = true, features = ["extension-module"] } +pyo3 = { version = "0.18.3", optional = true, features = ["extension-module", "anyhow"] } pyo3-asyncio = { version = "0.18", features = ["attributes", "tokio-runtime"], optional = true } neon = { version = "0.10", optional = true, default-features = false, features = ["napi-6", "promise-api", "channel-api"] } itertools = "0.10.5" @@ -36,6 +36,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json"] } indicatif = "0.17.6" serde = "1.0.181" futures = "0.3.28" +async-recursion = "1.0.4" [features] default = [] diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 656db9886..77d111b0f 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -3,14 +3,16 @@ use std::fs::OpenOptions; use std::io::Write; const ADDITIONAL_DEFAULTS_FOR_PYTHON: &[u8] = br#" -def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def 
init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +async def migrate() -> None Json = Any DateTime = int "#; const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" -export function js_init_logger(level?: string, format?: string): void; +export function init_logger(level?: string, format?: string): void; +export function migrate(): Promise; export type Json = { [key: string]: any }; export type DateTime = Date; diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index f4895edf4..19e2373d4 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -10,7 +10,7 @@ import pgml from "../../index.js"; //////////////////////////////////////////////////////////////////////////////////// const LOG_LEVEL = process.env.LOG_LEVEL ? process.env.LOG_LEVEL : "ERROR"; -pgml.js_init_logger(LOG_LEVEL); +pgml.init_logger(LOG_LEVEL); const generate_dummy_documents = (count: number) => { let docs = []; @@ -220,3 +220,11 @@ it("can delete documents", async () => { await collection.archive(); }); + +/////////////////////////////////////////////////// +// Test migrations //////////////////////////////// +/////////////////////////////////////////////////// + +it("can migrate", async () => { + await pgml.migrate(); +}); diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 9ef3103be..9b1df22d1 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -1,5 +1,6 @@ -def py_init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +def init_logger(level: Optional[str] = "", format: Optional[str] = "") -> None +async def migrate() -> None Json = Any DateTime = int diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index a355b27a8..7b369d433 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -19,7 +19,7 @@ print("No DATABASE_URL environment variable found. 
Please set one") exit(1) -pgml.py_init_logger() +pgml.init_logger() def generate_dummy_documents(count: int) -> List[Dict[str, Any]]: @@ -250,6 +250,16 @@ async def test_delete_documents(): await collection.archive() +################################################### +## Migration tests ################################ +################################################### + + +@pytest.mark.asyncio +async def test_migrate(): + await pgml.migrate() + + ################################################### ## Test with multiprocessing ###################### ################################################### diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 7dd887a34..db023b951 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -98,7 +98,7 @@ mod tests { async fn can_query() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let builtins = Builtins::new(None); - let query = "SELECT 10"; + let query = "SELECT * from pgml.collections"; let results = builtins.query(query).fetch_all().await?; assert!(results.as_array().is_some()); Ok(()) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 2f76ab1b9..9dd6bf95d 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -210,7 +210,7 @@ impl Collection { .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str()) .await?; - let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") + let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") .bind(&self.name) .bind(project_id) .bind(crate::SDK_VERSION) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index e6a4868f3..48fcf815f 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -12,11 +12,11 @@ use tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; -mod migrations; mod builtins; mod collection; mod filter_builder; mod languages; +pub mod migrations; mod model; pub mod models; mod pipeline; @@ -133,10 +133,20 @@ fn init_logger(level: Option, format: Option) -> pyo3::PyResult< Ok(()) } +#[cfg(feature = "python")] +#[pyo3::prelude::pyfunction] +fn migrate(py: pyo3::Python) -> pyo3::PyResult<&pyo3::PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + migrations::migrate().await?; + Ok(()) + }) +} + #[cfg(feature = "python")] #[pyo3::pymodule] fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_function(pyo3::wrap_pyfunction!(init_logger, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(migrate, m)?)?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -158,10 +168,30 @@ fn init_logger( ().into_js_result(&mut cx) } +#[cfg(feature = "javascript")] +fn migrate( + mut cx: neon::context::FunctionContext, +) -> neon::result::JsResult { + use neon::prelude::*; + use rust_bridge::javascript::IntoJsResult; + let channel = cx.channel(); + let (deferred, promise) = cx.promise(); + deferred + .try_settle_with(&channel, move |mut cx| { + let runtime = crate::get_or_set_runtime(); + let x = runtime.block_on(migrations::migrate()); + let x = x.expect("Error running migration"); + x.into_js_result(&mut cx) + }) + .expect("Error sending js"); + Ok(promise) +} + #[cfg(feature = "javascript")] #[neon::main] fn main(mut cx: 
neon::context::ModuleContext) -> neon::result::NeonResult<()> { cx.export_function("init_logger", init_logger)?; + cx.export_function("migrate", migrate)?; cx.export_function("newCollection", collection::CollectionJavascript::new)?; cx.export_function("newModel", model::ModelJavascript::new)?; cx.export_function("newSplitter", splitter::SplitterJavascript::new)?; @@ -263,6 +293,46 @@ mod tests { Ok(()) } + #[sqlx::test] + async fn can_specify_custom_hnsw_parameters_for_pipelines() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cschpfp_0", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "hnsw": { + "m": 100, + "ef_construction": 200 + } + }) + .into(), + ), + ); + let collection_name = "test_r_c_cschpfp_1"; + let mut collection = Collection::new(collection_name, None); + collection.add_pipeline(&mut pipeline).await?; + let full_embeddings_table_name = pipeline.create_or_get_embeddings_table().await?; + let embeddings_table_name = full_embeddings_table_name.split(".").collect::>()[1]; + let pool = get_or_initialize_pool(&None).await?; + let results: Vec<(String, String)> = sqlx::query_as(&query_builder!( + "select indexname, indexdef from pg_indexes where tablename = '%d' and schemaname = '%d'", + embeddings_table_name, + collection_name + )).fetch_all(&pool).await?; + let names = results.iter().map(|(name, _)| name).collect::>(); + let definitions = results + .iter() + .map(|(_, definition)| definition) + .collect::>(); + assert!(names.contains(&&format!("{}_pipeline_hnsw_vector_index", pipeline.name))); + assert!(definitions.contains(&&format!("CREATE INDEX {}_pipeline_hnsw_vector_index ON {} USING hnsw (embedding vector_cosine_ops) WITH (m='100', ef_construction='200')", pipeline.name, full_embeddings_table_name))); + Ok(()) + } + #[sqlx::test] async fn disable_enable_pipeline() -> anyhow::Result<()> { let model = Model::default(); @@ -492,7 +562,43 @@ mod tests { } #[sqlx::test] - async fn can_filter_documents() -> anyhow::Result<()> { + async fn can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::default(); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", Some(model), Some(splitter), None); + let mut collection = Collection::new("test_r_c_cvswqb_3", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); + collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "hnsw": { + "ef_search": 2 + } + }) + .into(), + ), + ) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_filter_vector_search() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); let model = Model::new(None, None, None); let splitter = Splitter::new(None, None); @@ -841,8 +947,8 @@ mod tests { #[sqlx::test] async fn can_filter_and_delete_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let model = Model::new(None, None, None); - let splitter = Splitter::new(None, None); + let model = Model::default(); + let splitter = Splitter::default(); 
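// Context for the can_specify_custom_hnsw_parameters_for_pipelines test above:
// in pgvector, `m` is the maximum number of graph connections per layer and
// `ef_construction` is the candidate-list size used while building the index
// (defaults 16 and 64); larger values trade index build time for recall.
// The asserted DDL has roughly this shape (an illustrative sketch, with
// placeholder names):
const EXPECTED_INDEX_SHAPE: &str =
    "CREATE INDEX <pipeline>_pipeline_hnsw_vector_index \
     ON <collection>.<pipeline>_embeddings \
     USING hnsw (embedding vector_cosine_ops) \
     WITH (m = 100, ef_construction = 200)";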
        let mut pipeline = Pipeline::new(
            "test_r_p_cfadd_1",
            Some(model),
diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs
index e1418698f..158118453 100644
--- a/pgml-sdks/pgml/src/migrations/mod.rs
+++ b/pgml-sdks/pgml/src/migrations/mod.rs
@@ -1,4 +1,4 @@
-use futures::FutureExt;
+use futures::{future::BoxFuture, FutureExt};
 use itertools::Itertools;
 use sqlx::PgPool;
 use tracing::instrument;
@@ -8,57 +8,58 @@ use crate::get_or_initialize_pool;
 #[path = "pgml--0.9.1--0.9.2.rs"]
 mod pgml091_092;
 
-// There is probably a better way to write these types and bypass the need for the closure pass
-// through, but it is proving to be difficult
-// We could also probably remove some unnecessary clones in the call_migrate function if I was savy
-// enough to reconcile the lifetimes
+// There is probably a better way to write this type and the version_migrations variable in the dispatch_migrations function
 type MigrateFn =
-    &'static dyn Fn(PgPool, Vec<i64>) -> futures::future::BoxFuture<'static, anyhow::Result<()>>;
-const VERSION_MIGRATIONS: &'static [(&'static str, MigrateFn)] =
-    &[("0.9.2", &|p, c| pgml091_092::migrate(p, c).boxed())];
+    Box<dyn Fn(PgPool, Vec<i64>) -> BoxFuture<'static, anyhow::Result<String>> + Send + Sync>;
 
 #[instrument]
-pub async fn migrate() -> anyhow::Result<()> {
-    let pool = get_or_initialize_pool(&None).await?;
-    let results: Result<Vec<(String, i64)>, _> =
-        sqlx::query_as("SELECT version, id FROM pgml.collections")
-            .fetch_all(&pool)
-            .await;
-    match results {
-        Ok(collections) => {
-            let collections = collections.into_iter().into_group_map();
-            for (version, collection_ids) in collections.into_iter() {
-                call_migrate(pool.clone(), version, collection_ids).await?
+pub fn migrate() -> BoxFuture<'static, anyhow::Result<()>> {
+    async move {
+        let pool = get_or_initialize_pool(&None).await?;
+        let results: Result<Vec<(String, i64)>, _> =
+            sqlx::query_as("SELECT sdk_version, id FROM pgml.collections")
+                .fetch_all(&pool)
+                .await;
+        match results {
+            Ok(collections) => {
+                dispatch_migrations(pool, collections).await?;
                 Ok(())
             }
-            Ok(())
-        }
-        Err(error) => {
-            let morphed_error = error
-                .as_database_error()
-                .map(|e| e.code().map(|c| c.to_string()));
-            if let Some(Some(db_error_code)) = morphed_error {
-                if db_error_code == "42703" {
-                    pgml091_092::migrate(pool, vec![]).await
+            Err(error) => {
+                let morphed_error = error
+                    .as_database_error()
+                    .map(|e| e.code().map(|c| c.to_string()));
+                if let Some(Some(db_error_code)) = morphed_error {
+                    if db_error_code == "42703" {
+                        pgml091_092::migrate(pool, vec![]).await?;
+                        migrate().await?;
+                        Ok(())
+                    } else {
+                        anyhow::bail!(error)
+                    }
                 } else {
                     anyhow::bail!(error)
                 }
-            } else {
-                anyhow::bail!(error)
-            }
             }
         }
     }
+    .boxed()
 }
 
-async fn call_migrate(
-    pool: PgPool,
-    version: String,
-    collection_ids: Vec<i64>,
-) -> anyhow::Result<()> {
-    let position = VERSION_MIGRATIONS.iter().position(|(v, _)| v == &version);
-    if let Some(p) = position {
-        // We run each migration in order that needs to be run for the collections
-        for (_, callback) in VERSION_MIGRATIONS.iter().skip(p + 1) {
-            callback(pool.clone(), collection_ids.clone()).await?
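// The dispatch_migrations rewrite below replaces the positional lookup above:
// each migration receives the collection ids grouped under its "from" version
// and returns the version it migrated them to, so those ids are merged into
// the next bucket and picked up by the next entry in the table. A toy model
// of that bookkeeping (pool and error handling omitted):
use std::collections::HashMap;

fn toy_dispatch(mut buckets: HashMap<String, Vec<i64>>) -> HashMap<String, Vec<i64>> {
    let steps = [("0.9.1", "0.9.2")]; // "from" version -> "to" version
    for (from, to) in steps {
        if let Some(ids) = buckets.remove(from) {
            buckets.entry(to.to_string()).or_default().extend(ids);
        }
    }
    buckets
}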
+async fn dispatch_migrations(pool: PgPool, collections: Vec<(String, i64)>) -> anyhow::Result<()> {
+    // The version of the SDK that the migration was written for, and the migration function
+    let version_migrations: [(&'static str, MigrateFn); 1] =
+        [("0.9.1", Box::new(|p, c| pgml091_092::migrate(p, c).boxed()))];
+
+    let mut collections = collections.into_iter().into_group_map();
+    for (version, migration) in version_migrations.into_iter() {
+        if let Some(collection_ids) = collections.remove(version) {
+            let new_version = migration(pool.clone(), collection_ids.clone()).await?;
+            if let Some(new_collection_ids) = collections.get_mut(&new_version) {
+                new_collection_ids.extend(collection_ids);
+            } else {
+                collections.insert(new_version, collection_ids);
+            }
+        }
+    }
     Ok(())
 }
diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
index adcc18b3c..63ce68bb2 100644
--- a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
+++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
@@ -4,7 +4,7 @@ use sqlx::PgPool;
 use tracing::instrument;
 
 #[instrument(skip(pool))]
-pub async fn migrate(pool: PgPool, _: Vec<i64>) -> anyhow::Result<()> {
+pub async fn migrate(pool: PgPool, _: Vec<i64>) -> anyhow::Result<String> {
     let collection_names: Vec<String> = sqlx::query_scalar("SELECT name FROM pgml.collections")
         .fetch_all(&pool)
         .await?;
@@ -16,18 +16,30 @@ pub async fn migrate(pool: PgPool, _: Vec<i64>) -> anyhow::Result<String> {
                 .await?;
         for pipeline_name in pipeline_names {
             let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name);
+            let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name);
             pool.execute(
                 query_builder!(
                     queries::CREATE_INDEX_USING_HNSW,
                     "",
-                    "hnsw_vector_index",
+                    index_name,
                     &table_name,
-                    "embedding vector_cosine_ops"
+                    "embedding vector_cosine_ops",
+                    ""
                 )
                 .as_str(),
             )
             .await?;
         }
+        // We can get rid of the old IVFFlat index now. There was a bug where we named it the same
+        // thing no matter what, so we only need to remove one index.
+        pool.execute(
+            query_builder!(
+                "DROP INDEX CONCURRENTLY IF EXISTS %s.vector_index",
+                collection_name
+            )
+            .as_str(),
+        )
+        .await?;
     }
 
     // Required to set the default value for a not null column being added, but we want to remove
@@ -38,5 +50,5 @@ pub async fn migrate(pool: PgPool, _: Vec<i64>) -> anyhow::Result<String> {
         .execute("ALTER TABLE pgml.collections ALTER COLUMN sdk_version DROP DEFAULT")
         .await?;
     transaction.commit().await?;
-    Ok(())
+    Ok("0.9.2".to_string())
 }
diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs
index e9c76e0ae..dceff4270 100644
--- a/pgml-sdks/pgml/src/pipeline.rs
+++ b/pgml-sdks/pgml/src/pipeline.rs
@@ -14,7 +14,7 @@ use crate::{
     models, queries, query_builder,
     remote_embeddings::build_remote_embeddings,
     splitter::Splitter,
-    types::{DateTime, Json},
+    types::{DateTime, Json, TryToNumeric},
     utils,
 };
 
@@ -591,19 +591,16 @@ impl Pipeline {
     }
 
     #[instrument(skip(self))]
-    async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result<String> {
+    pub(crate) async fn create_or_get_embeddings_table(&mut self) -> anyhow::Result<String> {
         self.verify_in_database(false).await?;
         let pool = self.get_pool().await?;
 
-        let embeddings_table_name = format!(
-            "{}.{}_embeddings",
-            &self
-                .project_info
-                .as_ref()
-                .context("Pipeline must have project info to get the embeddings table name")?
-                .name,
-            self.name
-        );
+        let collection_name = &self
+            .project_info
+            .as_ref()
+            .context("Pipeline must have project info to get the embeddings table name")?
+ .name; + let embeddings_table_name = format!("{}.{}_embeddings", collection_name, self.name); // Notice that we actually check for existence of the table in the database instead of // blindly creating it with `CREATE TABLE IF NOT EXISTS`. This is because we want to avoid @@ -623,9 +620,9 @@ impl Pipeline { .as_ref() .context("Pipeline must be verified to create embeddings table")?; - // Remove the stored name from the parameters - let mut parameters = model.parameters.clone(); - parameters + // Remove the stored name from the model parameters + let mut model_parameters = model.parameters.clone(); + model_parameters .as_object_mut() .context("Model parameters must be an object")? .remove("name"); @@ -635,13 +632,13 @@ impl Pipeline { let embedding: (Vec,) = sqlx::query_as( "SELECT embedding from pgml.embed(transformer => $1, text => 'Hello, World!', kwargs => $2) as embedding") .bind(&model.name) - .bind(parameters) + .bind(model_parameters) .fetch_one(&pool).await?; embedding.0.len() as i64 } t => { let remote_embeddings = - build_remote_embeddings(t.to_owned(), &model.name, &model.parameters)?; + build_remote_embeddings(t.to_owned(), &model.name, &model_parameters)?; remote_embeddings.get_embedding_size().await? } }; @@ -661,38 +658,65 @@ impl Pipeline { )) .execute(&mut *transaction) .await?; + let index_name = format!("{}_pipeline_created_at_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX, "", - "created_at_index", + index_name, &embeddings_table_name, "created_at" ) .as_str(), ) .await?; + let index_name = format!("{}_pipeline_chunk_id_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX, "", - "chunk_id_index", + index_name, &embeddings_table_name, "chunk_id" ) .as_str(), ) .await?; + // See: https://github.com/pgvector/pgvector + let (m, ef_construction) = match &self.parameters { + Some(p) => { + let m = if !p["hnsw"]["m"].is_null() { + p["hnsw"]["m"] + .try_to_u64() + .context("hnsw.m must be an integer")? + } else { + 16 + }; + let ef_construction = if !p["hnsw"]["ef_construction"].is_null() { + p["hnsw"]["ef_construction"] + .try_to_u64() + .context("hnsw.ef_construction must be an integer")? 
+ } else { + 64 + }; + (m, ef_construction) + } + None => (16, 64), + }; + let index_with_parameters = + format!("WITH (m = {}, ef_construction = {})", m, ef_construction); + let index_name = format!("{}_pipeline_hnsw_vector_index", self.name); transaction .execute( query_builder!( queries::CREATE_INDEX_USING_HNSW, "", - "hnsw_vector_index", + index_name, &embeddings_table_name, - "embedding vector_cosine_ops" + "embedding vector_cosine_ops", + index_with_parameters ) .as_str(), ) @@ -788,11 +812,23 @@ impl Pipeline { project_info: &ProjectInfo, conn: &mut PgConnection, ) -> anyhow::Result<()> { + let pipelines_table_name = format!("{}.pipelines", project_info.name); sqlx::query(&query_builder!( queries::CREATE_PIPELINES_TABLE, - &format!("{}.pipelines", project_info.name) + pipelines_table_name )) - .execute(conn) + .execute(&mut *conn) + .await?; + conn.execute( + query_builder!( + queries::CREATE_INDEX, + "", + "pipeline_name_index", + pipelines_table_name, + "name" + ) + .as_str(), + ) .await?; Ok(()) } diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 02bd985b9..254b92248 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -90,7 +90,7 @@ CREATE INDEX %d IF NOT EXISTS %s ON %s USING GIN (%d); "#; pub const CREATE_INDEX_USING_HNSW: &str = r#" -CREATE INDEX %d IF NOT EXISTS %S on %s using hnsw (%d); +CREATE INDEX %d IF NOT EXISTS %s on %s using hnsw (%d) %d; "#; ///////////////////////////// diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index a759cc7e4..59881af64 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -13,7 +13,7 @@ use crate::{ models, pipeline::Pipeline, remote_embeddings::build_remote_embeddings, - types::{IntoTableNameAndSchema, Json, SIden}, + types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, Collection, }; @@ -120,6 +120,7 @@ impl QueryBuilder { // Save these in case of failure self.pipeline = Some(pipeline.clone()); self.query_string = Some(query.to_owned()); + self.query_parameters = query_parameters.clone(); let query_parameters = query_parameters.unwrap_or_default().0; let embeddings_table_name = @@ -218,13 +219,34 @@ impl QueryBuilder { pub async fn fetch_all(mut self) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.collection.database_url).await?; - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); + let query_parameters = self.query_parameters.unwrap_or_default(); + let result: Result, _> = - sqlx::query_as_with(&sql, values).fetch_all(&pool).await; + if !query_parameters["hnsw"]["ef_search"].is_null() { + let mut transaction = pool.begin().await?; + let ef_search = query_parameters["hnsw"]["ef_search"] + .try_to_i64() + .context("ef_search must be an integer")?; + sqlx::query("SET LOCAL hnsw.ef_search = $1") + .bind(ef_search) + .execute(&mut *transaction) + .await?; + let (sql, values) = self + .query + .clone() + .with(self.with.clone()) + .build_sqlx(PostgresQueryBuilder); + let results = sqlx::query_as_with(&sql, values).fetch_all(&mut *transaction).await; + transaction.commit().await?; + results + } else { + let (sql, values) = self + .query + .clone() + .with(self.with.clone()) + .build_sqlx(PostgresQueryBuilder); + sqlx::query_as_with(&sql, values).fetch_all(&pool).await + }; match result { Ok(r) => Ok(r), @@ -249,8 +271,6 @@ impl QueryBuilder { return Err(anyhow::anyhow!(e)); } - let query_parameters = 
self.query_parameters.to_owned().unwrap_or_default();
 
                         let remote_embeddings =
                             build_remote_embeddings(model.runtime, &model.name, &query_parameters)?;
                         let mut embeddings = remote_embeddings
diff --git a/pgml-sdks/pgml/src/query_runner.rs b/pgml-sdks/pgml/src/query_runner.rs
index ff8f3fa8f..623a09662 100644
--- a/pgml-sdks/pgml/src/query_runner.rs
+++ b/pgml-sdks/pgml/src/query_runner.rs
@@ -46,9 +46,12 @@ impl QueryRunner {
         let pool = get_or_initialize_pool(&self.database_url).await?;
         self.query = format!("SELECT json_agg(j) FROM ({}) j", self.query);
         let query = self.build_query();
-        let results = query.fetch_all(&pool).await?;
-        let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
-        Ok(Json(results))
+        let results = query.fetch_one(&pool).await?;
+        let results = results.try_get::<serde_json::Value, _>(0);
+        match results {
+            Ok(r) => Ok(Json(r)),
+            _ => Ok(Json(serde_json::json!([]))),
+        }
     }
 
     pub async fn execute(self) -> anyhow::Result<()> {
diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs
index d3d1ce306..f7bd4cfd1 100644
--- a/pgml-sdks/pgml/src/types.rs
+++ b/pgml-sdks/pgml/src/types.rs
@@ -44,6 +44,9 @@ impl Serialize for Json {
 
 pub(crate) trait TryToNumeric {
     fn try_to_u64(&self) -> anyhow::Result<u64>;
+    fn try_to_i64(&self) -> anyhow::Result<i64> {
+        self.try_to_u64().map(|u| u as i64)
+    }
 }
 
 impl TryToNumeric for serde_json::Value {

From 95789a52b05d0c9cc9a9efae64760d8f0a2ab929 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Tue, 5 Sep 2023 09:10:11 -0700
Subject: [PATCH 03/11] Working HNSW

---
 pgml-sdks/pgml/src/lib.rs           | 46 ++++++++++++++++--
 pgml-sdks/pgml/src/query_builder.rs | 74 +++++++++++++++++++++--------
 2 files changed, 98 insertions(+), 22 deletions(-)

diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index 48fcf815f..d77f59e9c 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -567,12 +567,52 @@ mod tests {
         internal_init_logger(None, None).ok();
         let model = Model::default();
         let splitter = Splitter::default();
-        let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", Some(model), Some(splitter), None);
-        let mut collection = Collection::new("test_r_c_cvswqb_3", None);
+        let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None);
+        let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None);
         collection.add_pipeline(&mut pipeline).await?;
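// On the ef_search parameter exercised in this patch: hnsw.ef_search is
// pgvector's search-time candidate-list size (default 40), and raising it
// trades speed for recall. It is a GUC applied with SET LOCAL, which only
// lasts for the current transaction, so the setting and the ANN query must
// share one transaction, as the fetch_all change later in this patch
// arranges. It is also interpolated rather than bound, presumably because
// PostgreSQL does not accept bind parameters in utility statements like SET.
// A minimal sketch, assuming an existing sqlx PgPool named `pool`:
let mut tx = pool.begin().await?;
sqlx::query("SET LOCAL hnsw.ef_search = 2")
    .execute(&mut *tx)
    .await?;
// ... run the vector recall on `tx` rather than the pool, then:
tx.commit().await?;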
+ // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", None, None, None); collection .upsert_documents(generate_dummy_documents(3)) .await?; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 59881af64..184e229b7 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -6,12 +6,14 @@ use sea_query::{ }; use sea_query_binder::SqlxBinder; use std::borrow::Cow; +use tracing::instrument; use crate::{ filter_builder, get_or_initialize_pool, model::ModelRuntime, models, pipeline::Pipeline, + query_builder, remote_embeddings::build_remote_embeddings, types::{IntoTableNameAndSchema, Json, SIden, TryToNumeric}, Collection, @@ -46,11 +48,13 @@ impl QueryBuilder { } } + #[instrument(skip(self))] pub fn limit(mut self, limit: u64) -> Self { self.query.limit(limit); self } + #[instrument(skip(self))] pub fn filter(mut self, mut filter: Json) -> Self { let filter = filter .0 @@ -65,12 +69,14 @@ impl QueryBuilder { self } + #[instrument(skip(self))] fn filter_metadata(mut self, filter: serde_json::Value) -> Self { let filter = filter_builder::FilterBuilder::new(filter, "documents", "metadata").build(); self.query.cond_where(filter); self } + #[instrument(skip(self))] fn filter_full_text(mut self, mut filter: serde_json::Value) -> Self { let filter = filter .as_object_mut() @@ -111,6 +117,7 @@ impl QueryBuilder { self } + #[instrument(skip(self))] pub fn vector_recall( mut self, query: &str, @@ -122,7 +129,12 @@ impl QueryBuilder { self.query_string = Some(query.to_owned()); self.query_parameters = query_parameters.clone(); - let query_parameters = query_parameters.unwrap_or_default().0; + let mut query_parameters = query_parameters.unwrap_or_default().0; + // If they did set hnsw, remove it before we pass it to the model + query_parameters + .as_object_mut() + .expect("Query parameters must be a Json object") + .remove("hnsw"); let embeddings_table_name = format!("{}.{}_embeddings", self.collection.name, pipeline.name); @@ -216,10 +228,17 @@ impl QueryBuilder { self } + #[instrument(skip(self))] pub async fn fetch_all(mut self) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.collection.database_url).await?; - let query_parameters = self.query_parameters.unwrap_or_default(); + let mut query_parameters = self.query_parameters.unwrap_or_default(); + + let (sql, values) = self + .query + .clone() + .with(self.with.clone()) + .build_sqlx(PostgresQueryBuilder); let result: Result, _> = if !query_parameters["hnsw"]["ef_search"].is_null() { @@ -227,24 +246,15 @@ impl QueryBuilder { let ef_search = query_parameters["hnsw"]["ef_search"] .try_to_i64() .context("ef_search must be an integer")?; - sqlx::query("SET LOCAL hnsw.ef_search = $1") - .bind(ef_search) + sqlx::query(&query_builder!("SET LOCAL hnsw.ef_search = %d", ef_search)) .execute(&mut *transaction) .await?; - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); - let results = sqlx::query_as_with(&sql, values).fetch_all(&mut *transaction).await; + let results = sqlx::query_as_with(&sql, values) + .fetch_all(&mut *transaction) + .await; transaction.commit().await?; results } else { - let (sql, values) = self - .query - .clone() - .with(self.with.clone()) - .build_sqlx(PostgresQueryBuilder); sqlx::query_as_with(&sql, values).fetch_all(&pool).await }; @@ -252,6 +262,8 @@ impl QueryBuilder { Ok(r) => Ok(r), Err(e) => match 
e.as_database_error() {
                 Some(d) => {
+                    println!("THE ERORR: {:?}", d);
+                    println!("THE ERROR CODE: {:?}", d.code());
                     if d.code() == Some(Cow::from("XX000")) {
                         // Explicitly get and set the model
                         let project_info = self.collection.get_project_info().await?;
@@ -278,11 +278,18 @@ impl QueryBuilder {
                             .as_ref()
                             .context("Pipeline must be verified to perform vector search with remote embeddings")?;
 
+                        println!("THE MODEL: {:?}", model);
+
                         // If the model runtime is python, the error was not caused by an unsupported runtime
                         if model.runtime == ModelRuntime::Python {
                             return Err(anyhow::anyhow!(e));
                         }
 
+                        let hnsw_parameters = query_parameters
+                            .as_object_mut()
+                            .context("Query parameters must be a Json object")?
+                            .remove("hnsw");
+
                         let remote_embeddings =
                             build_remote_embeddings(model.runtime, &model.name, &query_parameters)?;
                         let mut embeddings = remote_embeddings
@@ -308,10 +327,27 @@ impl QueryBuilder {
                             .clone()
                             .with(with_clause)
                             .build_sqlx(PostgresQueryBuilder);
-                        sqlx::query_as_with(&sql, values)
-                            .fetch_all(&pool)
-                            .await
-                            .map_err(|e| anyhow::anyhow!(e))
+
+                        if let Some(parameters) = hnsw_parameters {
+                            let mut transaction = pool.begin().await?;
+                            let ef_search = parameters["ef_search"]
+                                .try_to_i64()
+                                .context("ef_search must be an integer")?;
+                            sqlx::query(&query_builder!(
+                                "SET LOCAL hnsw.ef_search = %d",
+                                ef_search
+                            ))
+                            .execute(&mut *transaction)
+                            .await?;
+                            let results = sqlx::query_as_with(&sql, values)
+                                .fetch_all(&mut *transaction)
+                                .await;
+                            transaction.commit().await?;
+                            results
+                        } else {
+                            sqlx::query_as_with(&sql, values).fetch_all(&pool).await
+                        }
+                        .map_err(|e| anyhow::anyhow!(e))
                     } else {
                         Err(anyhow::anyhow!(e))
                     }

From 3044bb83f4d6f226246412db69a8933be3b0300a Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Tue, 5 Sep 2023 11:52:39 -0700
Subject: [PATCH 04/11] Cleaned up and ready to go

---
 .../javascript/tests/typescript-tests/test.ts | 46 ++++++++++
 pgml-sdks/pgml/python/pgml/pgml.pyi           | 86 -------------------
 pgml-sdks/pgml/python/tests/test.py           | 38 ++++++++
 pgml-sdks/pgml/src/lib.rs                     | 12 ++-
 .../pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 34 +++++---
 pgml-sdks/pgml/src/query_builder.rs           |  4 -
 6 files changed, 114 insertions(+), 106 deletions(-)

diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts
index 19e2373d4..c9113a04c 100644
--- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts
+++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts
@@ -143,6 +143,52 @@ it("can vector search with query builder and metadata filtering", async () => {
   await collection.archive();
 });
 
+it("can vector search with query builder and custom hnsw ef_search value", async () => {
+  let model = pgml.newModel();
+  let splitter = pgml.newSplitter();
+  let pipeline = pgml.newPipeline("test_j_p_cvswqbachesv_0", model, splitter);
+  let collection = pgml.newCollection("test_j_c_cvswqbachesv_0");
+  await collection.upsert_documents(generate_dummy_documents(3));
+  await collection.add_pipeline(pipeline);
+  let results = await collection
+    .query()
+    .vector_recall("Here is some query", pipeline)
+    .filter({
+      hnsw: {
+        ef_search: 2,
+      },
+    })
+    .limit(10)
+    .fetch_all();
+  expect(results).toHaveLength(3);
+  await collection.archive();
+});
+
+it("can vector search with query builder and custom hnsw ef_search value and remote embeddings", async () => {
+  let model = pgml.newModel("text-embedding-ada-002", "openai");
+  let splitter =
pgml.newSplitter(); + let pipeline = pgml.newPipeline( + "test_j_p_cvswqbachesvare_0", + model, + splitter, + ); + let collection = pgml.newCollection("test_j_c_cvswqbachesvare_0"); + await collection.upsert_documents(generate_dummy_documents(3)); + await collection.add_pipeline(pipeline); + let results = await collection + .query() + .vector_recall("Here is some query", pipeline) + .filter({ + hnsw: { + ef_search: 2, + }, + }) + .limit(10) + .fetch_all(); + expect(results).toHaveLength(3); + await collection.archive(); +}); + /////////////////////////////////////////////////// // Test user output facing functions ////////////// /////////////////////////////////////////////////// diff --git a/pgml-sdks/pgml/python/pgml/pgml.pyi b/pgml-sdks/pgml/python/pgml/pgml.pyi index 9b1df22d1..f043afd52 100644 --- a/pgml-sdks/pgml/python/pgml/pgml.pyi +++ b/pgml-sdks/pgml/python/pgml/pgml.pyi @@ -4,89 +4,3 @@ async def migrate() -> None Json = Any DateTime = int - -# Top of file key: A12BECOD! -from typing import List, Dict, Optional, Self, Any - - -class Builtins: - def __init__(self, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self - ... - def query(self, query: str) -> QueryRunner - ... - async def transform(self, task: Json, inputs: List[str], args: Optional[Json] = Any) -> Json - ... - -class Collection: - def __init__(self, name: str, database_url: Optional[str] = "Default set in Rust. Please check the documentation.") -> Self - ... - async def add_pipeline(self, pipeline: Pipeline) -> None - ... - async def remove_pipeline(self, pipeline: Pipeline) -> None - ... - async def enable_pipeline(self, pipeline: Pipeline) -> None - ... - async def disable_pipeline(self, pipeline: Pipeline) -> None - ... - async def upsert_documents(self, documents: List[Json]) -> None - ... - async def get_documents(self, args: Optional[Json] = Any) -> List[Json] - ... - async def delete_documents(self, filter: Json) -> None - ... - async def vector_search(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any, top_k: Optional[int] = 1) -> List[tuple[float, str, Json]] - ... - async def archive(self) -> None - ... - def query(self) -> QueryBuilder - ... - async def get_pipelines(self) -> List[Pipeline] - ... - async def get_pipeline(self, name: str) -> Pipeline - ... - async def exists(self) -> bool - ... - -class Model: - def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", source: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self - ... - -class Pipeline: - def __init__(self, name: str, model: Optional[Model] = Any, splitter: Optional[Splitter] = Any, parameters: Optional[Json] = Any) -> Self - ... - async def get_status(self) -> PipelineSyncData - ... - async def to_dict(self) -> Json - ... - -class QueryBuilder: - def limit(self, limit: int) -> Self - ... - def filter(self, filter: Json) -> Self - ... - def vector_recall(self, query: str, pipeline: Pipeline, query_parameters: Optional[Json] = Any) -> Self - ... - async def fetch_all(self) -> List[tuple[float, str, Json]] - ... - def to_full_string(self) -> str - ... - -class QueryRunner: - async def fetch_all(self) -> Json - ... - async def execute(self) -> None - ... - def bind_string(self, bind_value: str) -> Self - ... - def bind_int(self, bind_value: int) -> Self - ... - def bind_float(self, bind_value: float) -> Self - ... - def bind_bool(self, bind_value: bool) -> Self - ... 
-    def bind_json(self, bind_value: Json) -> Self
-    ...
-
-class Splitter:
-    def __init__(self, name: Optional[str] = "Default set in Rust. Please check the documentation.", parameters: Optional[Json] = Any) -> Self
-    ...
diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py
index 7b369d433..0b1632b0a 100644
--- a/pgml-sdks/pgml/python/tests/test.py
+++ b/pgml-sdks/pgml/python/tests/test.py
@@ -164,6 +164,44 @@ async def test_can_vector_search_with_query_builder_and_metadata_filtering():
     await collection.archive()
 
 
+@pytest.mark.asyncio
+async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value():
+    model = pgml.Model()
+    splitter = pgml.Splitter()
+    pipeline = pgml.Pipeline("test_p_p_tcvswqbachesv_0", model, splitter)
+    collection = pgml.Collection(name="test_p_c_tcvswqbachesv_0")
+    await collection.upsert_documents(generate_dummy_documents(3))
+    await collection.add_pipeline(pipeline)
+    results = (
+        await collection.query()
+        .vector_recall("Here is some query", pipeline)
+        .filter({"hnsw": {"ef_search": 2}})
+        .limit(10)
+        .fetch_all()
+    )
+    assert len(results) == 3
+    await collection.archive()
+
+
+@pytest.mark.asyncio
+async def test_can_vector_search_with_query_builder_and_custom_hnsw_ef_search_value_and_remote_embeddings():
+    model = pgml.Model(name="text-embedding-ada-002", source="openai")
+    splitter = pgml.Splitter()
+    pipeline = pgml.Pipeline("test_p_p_tcvswqbachesvare_0", model, splitter)
+    collection = pgml.Collection(name="test_p_c_tcvswqbachesvare_0")
+    await collection.upsert_documents(generate_dummy_documents(3))
+    await collection.add_pipeline(pipeline)
+    results = (
+        await collection.query()
+        .vector_recall("Here is some query", pipeline)
+        .filter({"hnsw": {"ef_search": 2}})
+        .limit(10)
+        .fetch_all()
+    )
+    assert len(results) == 3
+    await collection.archive()
+
+
 ###################################################
 ## Test user output facing functions ##############
 ###################################################
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index d77f59e9c..b501b0db3 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -36,7 +36,7 @@ pub use pipeline::Pipeline;
 pub use splitter::Splitter;
 
 // This is used when inserting collections to set the sdk_version used during creation
-static SDK_VERSION: &'static str = "0.9.2";
+static SDK_VERSION: &str = "0.9.2";
 
 // Store the database(s) in a global variable so that we can access them from anywhere
 // This is not necessarily idiomatic Rust, but it is a good way to accomplish what we need
@@ -567,7 +567,8 @@ mod tests {
         internal_init_logger(None, None).ok();
         let model = Model::default();
         let splitter = Splitter::default();
-        let mut pipeline = Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None);
+        let mut pipeline =
+            Pipeline::new("test_r_p_cvswqbachesv_1", Some(model), Some(splitter), None);
         let mut collection = Collection::new("test_r_c_cvswqbachesv_3", None);
         collection.add_pipeline(&mut pipeline).await?;
 
@@ -607,7 +608,12 @@ mod tests {
             None,
         );
         let splitter = Splitter::default();
-        let mut pipeline = Pipeline::new("test_r_p_cvswqbachesvare_2", Some(model), Some(splitter), None);
+        let mut pipeline = Pipeline::new(
+            "test_r_p_cvswqbachesvare_2",
+            Some(model),
+            Some(splitter),
+            None,
+        );
         let mut collection = Collection::new("test_r_c_cvswqbachesvare_7", None);
         collection.add_pipeline(&mut pipeline).await?;
 
diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs index 63ce68bb2..165bc6f0e 100644 --- a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs +++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs @@ -15,20 +15,28 @@ pub async fn migrate(pool: PgPool, _: Vec) -> anyhow::Result { .fetch_all(&pool) .await?; for pipeline_name in pipeline_names { - let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); - let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name); - pool.execute( - query_builder!( - queries::CREATE_INDEX_USING_HNSW, - "", - index_name, - &table_name, - "embedding vector_cosine_ops", - "" + let embeddings_table_name = format!("{}_embeddings", pipeline_name); + let exists: bool = sqlx::query_scalar("SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = $1 and table_schema = $2)") + .bind(embeddings_table_name) + .bind(&collection_name) + .fetch_one(&pool) + .await?; + if exists { + let table_name = format!("{}.{}_embeddings", collection_name, pipeline_name); + let index_name = format!("{}_pipeline_hnsw_vector_index", pipeline_name); + pool.execute( + query_builder!( + queries::CREATE_INDEX_USING_HNSW, + "", + index_name, + &table_name, + "embedding vector_cosine_ops", + "" + ) + .as_str(), ) - .as_str(), - ) - .await?; + .await?; + } } // We can get rid of the old IVFFlat index now. There was a bug where we named it the same // thing no matter what, so we only need to remove one index. diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 184e229b7..f7c02b991 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -262,8 +262,6 @@ impl QueryBuilder { Ok(r) => Ok(r), Err(e) => match e.as_database_error() { Some(d) => { - println!("THE ERORR: {:?}", d); - println!("THE ERROR CODE: {:?}", d.code()); if d.code() == Some(Cow::from("XX000")) { // Explicitly get and set the model let project_info = self.collection.get_project_info().await?; @@ -278,8 +276,6 @@ impl QueryBuilder { .as_ref() .context("Pipeline must be verified to perform vector search with remote embeddings")?; - println!("THE MODEL: {:?}", model); - // If the model runtime is python, the error was not caused by an unsupported runtime if model.runtime == ModelRuntime::Python { return Err(anyhow::anyhow!(e)); From 1ab68683e6bfef06f7633504f0bf906cd2bec672 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 5 Sep 2023 12:22:27 -0700 Subject: [PATCH 05/11] Renaming --- pgml-sdks/pgml/src/lib.rs | 2 +- pgml-sdks/pgml/src/migrations/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index b501b0db3..be2d998b0 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -508,7 +508,7 @@ mod tests { // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(30000)) .await?; let results = collection .query() diff --git a/pgml-sdks/pgml/src/migrations/mod.rs b/pgml-sdks/pgml/src/migrations/mod.rs index 158118453..b67dec8fa 100644 --- a/pgml-sdks/pgml/src/migrations/mod.rs +++ b/pgml-sdks/pgml/src/migrations/mod.rs @@ -71,7 +71,7 @@ mod tests { use crate::internal_init_logger; #[tokio::test] - async fn test_migrate() -> anyhow::Result<()> { + async fn can_migrate() -> 
anyhow::Result<()> { internal_init_logger(None, None).ok(); migrate().await?; Ok(()) From bbfdcb674ad8eca7b9fa521afc417bc005bfb2cb Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:56:26 -0700 Subject: [PATCH 06/11] Updated queries to use hnsw indices --- pgml-sdks/pgml/src/collection.rs | 2 - pgml-sdks/pgml/src/lib.rs | 66 ++++++++++++++++++++++++++--- pgml-sdks/pgml/src/queries.rs | 34 ++++----------- pgml-sdks/pgml/src/query_builder.rs | 53 +++++++---------------- 4 files changed, 83 insertions(+), 72 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 9dd6bf95d..82449a3df 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -926,7 +926,6 @@ impl Collection { queries::EMBED_AND_VECTOR_SEARCH, self.pipelines_table_name, embeddings_table_name, - embeddings_table_name, self.chunks_table_name, self.documents_table_name )) @@ -1012,7 +1011,6 @@ impl Collection { sqlx::query_as(&query_builder!( queries::VECTOR_SEARCH, embeddings_table_name, - embeddings_table_name, self.chunks_table_name, self.documents_table_name )) diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index be2d998b0..4fd02b154 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -467,7 +467,7 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswre_20", None); + let mut collection = Collection::new("test_r_c_cvswre_21", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example @@ -476,7 +476,7 @@ mod tests { .upsert_documents(generate_dummy_documents(3)) .await?; let results = collection - .vector_search("Here is some query", &mut pipeline, None, None) + .vector_search("Here is some query", &mut pipeline, None, Some(10)) .await?; assert!(results.len() == 3); collection.archive().await?; @@ -502,17 +502,70 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswqb_3", None); + let mut collection = Collection::new("test_r_c_cvswqb_4", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqb_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(30000)) + .upsert_documents(generate_dummy_documents(4)) .await?; let results = collection .query() .vector_recall("Here is some query", &mut pipeline, None) + .limit(3) + .fetch_all() + .await?; + assert!(results.len() == 3); + collection.archive().await?; + Ok(()) + } + + #[sqlx::test] + async fn can_vector_search_with_query_builder_and_pass_model_parameters_in_search( + ) -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let model = Model::new( + Some("hkunlp/instructor-base".to_string()), + Some("python".to_string()), + Some(json!({"instruction": "Represent the Wikipedia document for retrieval: "}).into()), + ); + let splitter = Splitter::default(); + let mut pipeline = Pipeline::new( + "test_r_p_cvswqbapmpis_1", + Some(model), + Some(splitter), + Some( + serde_json::json!({ + "full_text_search": { + "active": true, + "configuration": "english" + } + }) + .into(), + ), + ); + let mut collection = Collection::new("test_r_c_cvswqbapmpis_4", None); + collection.add_pipeline(&mut pipeline).await?; + + // Recreate the pipeline to replicate a more accurate example + let mut pipeline = Pipeline::new("test_r_p_cvswqbapmpis_1", None, None, None); + 
collection + .upsert_documents(generate_dummy_documents(3)) + .await?; + let results = collection + .query() + .vector_recall( + "Here is some query", + &mut pipeline, + Some( + json!({ + "instruction": "Represent the Wikipedia document for retrieval: " + }) + .into(), + ), + ) + .limit(10) .fetch_all() .await?; assert!(results.len() == 3); @@ -543,17 +596,18 @@ mod tests { .into(), ), ); - let mut collection = Collection::new("test_r_c_cvswqbwre_3", None); + let mut collection = Collection::new("test_r_c_cvswqbwre_5", None); collection.add_pipeline(&mut pipeline).await?; // Recreate the pipeline to replicate a more accurate example let mut pipeline = Pipeline::new("test_r_p_cvswqbwre_1", None, None, None); collection - .upsert_documents(generate_dummy_documents(3)) + .upsert_documents(generate_dummy_documents(4)) .await?; let results = collection .query() .vector_recall("Here is some query", &mut pipeline, None) + .limit(3) .fetch_all() .await?; assert!(results.len() == 3); diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 254b92248..b815a2f35 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs @@ -188,50 +188,32 @@ embedding AS ( text => $2, kwargs => $3 )::vector AS embedding -), -comparison AS ( - SELECT - chunk_id, - 1 - ( - %s.embedding <=> (SELECT embedding FROM embedding) - ) AS score - FROM - %s ) SELECT - comparison.score, + embeddings.embedding <=> (SELECT embedding FROM embedding) score, chunks.chunk, documents.metadata FROM - comparison - INNER JOIN %s chunks ON chunks.id = comparison.chunk_id + %s embeddings + INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id INNER JOIN %s documents ON documents.id = chunks.document_id ORDER BY - comparison.score DESC + score ASC LIMIT $4; "#; pub const VECTOR_SEARCH: &str = r#" -WITH comparison AS ( - SELECT - chunk_id, - 1 - ( - %s.embedding <=> $1::vector - ) AS score - FROM - %s -) SELECT - comparison.score, + embeddings.embedding <=> $1::vector score, chunks.chunk, documents.metadata FROM - comparison - INNER JOIN %s chunks ON chunks.id = comparison.chunk_id + %s embeddings + INNER JOIN %s chunks ON chunks.id = embeddings.chunk_id INNER JOIN %s documents ON documents.id = chunks.document_id ORDER BY - comparison.score DESC + score ASC LIMIT $2; "#; diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index f7c02b991..410e1b4be 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -178,43 +178,33 @@ impl QueryBuilder { let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); embedding_cte.table_name(Alias::new("embedding")); - // Build the comparison CTE - let mut comparison_cte = Query::select(); - comparison_cte - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .columns([models::EmbeddingIden::ChunkId]) - .expr(Expr::cust( - "1 - (embeddings.embedding <=> (select embedding from embedding)) as score", - )); - let mut comparison_cte = CommonTableExpression::from_select(comparison_cte); - comparison_cte.table_name(Alias::new("comparison")); - // Build the where clause let mut with_clause = WithClause::new(); self.with = with_clause .cte(pipeline_cte) .cte(model_cte) .cte(embedding_cte) - .cte(comparison_cte) .to_owned(); // Build the query self.query + .expr(Expr::cust( + "(embeddings.embedding <=> (SELECT embedding from embedding)) score", + )) .columns([ - (SIden::Str("comparison"), SIden::Str("score")), (SIden::Str("chunks"), 
SIden::Str("chunk")), (SIden::Str("documents"), SIden::Str("metadata")), ]) - .from(SIden::Str("comparison")) + .from_as( + embeddings_table_name.to_table_tuple(), + SIden::Str("embeddings"), + ) .join_as( JoinType::InnerJoin, self.collection.chunks_table_name.to_table_tuple(), Alias::new("chunks"), Expr::col((SIden::Str("chunks"), SIden::Str("id"))) - .equals((SIden::Str("comparison"), SIden::Str("chunk_id"))), + .equals((SIden::Str("embeddings"), SIden::Str("chunk_id"))), ) .join_as( JoinType::InnerJoin, @@ -223,7 +213,7 @@ impl QueryBuilder { Expr::col((SIden::Str("documents"), SIden::Str("id"))) .equals((SIden::Str("chunks"), SIden::Str("document_id"))), ) - .order_by((SIden::Str("comparison"), SIden::Str("score")), Order::Desc); + .order_by(SIden::Str("score"), Order::Asc); self } @@ -296,27 +286,14 @@ impl QueryBuilder { .await?; let embedding = std::mem::take(&mut embeddings[0]); - // Explicit drop required here or we can't borrow the pipeline immutably - drop(remote_embeddings); - let embeddings_table_name = - format!("{}.{}_embeddings", self.collection.name, pipeline.name); - - let mut comparison_cte = Query::select(); - comparison_cte - .from_as( - embeddings_table_name.to_table_tuple(), - SIden::Str("embeddings"), - ) - .columns([models::EmbeddingIden::ChunkId]) - .expr(Expr::cust_with_values( - "1 - (embeddings.embedding <=> $1::vector) as score", - [embedding], - )); + let mut embedding_cte = Query::select(); + embedding_cte + .expr(Expr::cust_with_values("$1::vector embedding", [embedding])); - let mut comparison_cte = CommonTableExpression::from_select(comparison_cte); - comparison_cte.table_name(Alias::new("comparison")); + let mut embedding_cte = CommonTableExpression::from_select(embedding_cte); + embedding_cte.table_name(Alias::new("embedding")); let mut with_clause = WithClause::new(); - with_clause.cte(comparison_cte); + with_clause.cte(embedding_cte); let (sql, values) = self .query From bc33baa9c9187e38c3b0690d4de7a419269c35d3 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:01:47 -0700 Subject: [PATCH 07/11] Updated score to be 1 - score --- pgml-sdks/pgml/src/collection.rs | 5 +++++ pgml-sdks/pgml/src/query_builder.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 82449a3df..c4b3e4cff 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -968,6 +968,11 @@ impl Collection { .await } } + .map(|r| { + r.into_iter() + .map(|(score, id, metadata)| (1. - score, id, metadata)) + .collect() + }) } #[instrument(skip(self, pool))] diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 410e1b4be..98fbe104a 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -327,7 +327,7 @@ impl QueryBuilder { } None => Err(anyhow::anyhow!(e)), }, - } + }.map(|r| r.into_iter().map(|(score, id, metadata)| (1. 
- score, id, metadata)).collect())
     }
 
     // This is mostly so our SDKs in other languages have some way to debug

From a36198f97f2efd9964f5167bfc38d378f3ef3aab Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 10:18:05 -0700
Subject: [PATCH 08/11] Cleaned up examples

---
 .../javascript/examples/extractive_question_answering.js   | 1 -
 .../javascript/examples/summarizing_question_answering.js  | 2 --
 .../pgml/python/examples/summarizing_question_answering.py | 5 +----
 3 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js
index fac0925ff..f70bf26b4 100644
--- a/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js
+++ b/pgml-sdks/pgml/javascript/examples/extractive_question_answering.js
@@ -1,7 +1,6 @@
 const pgml = require("pgml");
 require("dotenv").config();
 
-pgml.js_init_logger();
 
 const main = async () => {
   // Initialize the collection
diff --git a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js
index a5e5fe19b..f779cde60 100644
--- a/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js
+++ b/pgml-sdks/pgml/javascript/examples/summarizing_question_answering.js
@@ -1,8 +1,6 @@
 const pgml = require("pgml");
 require("dotenv").config();
 
-pgml.js_init_logger();
-
 const main = async () => {
   // Initialize the collection
   const collection = pgml.newCollection("my_javascript_sqa_collection");
diff --git a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py
index 4c291aac0..3008b31a9 100644
--- a/pgml-sdks/pgml/python/examples/summarizing_question_answering.py
+++ b/pgml-sdks/pgml/python/examples/summarizing_question_answering.py
@@ -1,4 +1,4 @@
-from pgml import Collection, Model, Splitter, Pipeline, Builtins, py_init_logger
+from pgml import Collection, Model, Splitter, Pipeline, Builtins
 import json
 from datasets import load_dataset
 from time import time
@@ -7,9 +7,6 @@ import asyncio
 
 
-py_init_logger()
-
-
 async def main():
     load_dotenv()
     console = Console()

From 1d0232e0b388d6add75b44135bdd22b769b83413 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 10:34:41 -0700
Subject: [PATCH 09/11] Added dependency on pgvector 0.5.0 and above for 0.9.2 migration

---
 pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
index 165bc6f0e..85c5165bb 100644
--- a/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
+++ b/pgml-sdks/pgml/src/migrations/pgml--0.9.1--0.9.2.rs
@@ -5,6 +5,17 @@ use tracing::instrument;
 
 #[instrument(skip(pool))]
 pub async fn migrate(pool: PgPool, _: Vec<i64>) -> anyhow::Result<String> {
+    pool.execute("ALTER EXTENSION vector UPDATE").await?;
+    let version: String =
+        sqlx::query_scalar("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
+            .fetch_one(&pool)
+            .await?;
+    let value = version.split(".").collect::<Vec<&str>>()[1].parse::<u32>()?;
+    anyhow::ensure!(
+        value >= 5,
+        "Vector extension must be at least version 0.5.0"
+    );
+
     let collection_names: Vec<String> = sqlx::query_scalar("SELECT name FROM pgml.collections")
         .fetch_all(&pool)
         .await?;
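For readers following the migration, this is roughly the statement the guarded index build above expands to for a single pipeline. The actual template lives in `queries::CREATE_INDEX_USING_HNSW`, which is not shown in this series, so the collection schema and pipeline names below are illustrative, not the SDK's exact output:

```rust
use sqlx::{Executor, PgPool};

// A minimal sketch of the per-pipeline HNSW index build performed by the
// 0.9.2 migration. Assumes a collection schema `my_collection` with a synced
// pipeline named `pipeline1`; both names are hypothetical.
async fn build_hnsw_index(pool: &PgPool) -> anyhow::Result<()> {
    pool.execute(
        "CREATE INDEX IF NOT EXISTS pipeline1_pipeline_hnsw_vector_index \
         ON my_collection.pipeline1_embeddings \
         USING hnsw (embedding vector_cosine_ops)",
    )
    .await?;
    Ok(())
}
```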
From 0c3988acb4fc7ef7e288d7ecf591093b754c3a87 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 6 Sep 2023 10:58:13 -0700
Subject: [PATCH 10/11] Updated README

---
 pgml-sdks/pgml/javascript/README.md | 18 ++++++++++++++++++
 pgml-sdks/pgml/python/README.md     | 18 ++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/pgml-sdks/pgml/javascript/README.md b/pgml-sdks/pgml/javascript/README.md
index de4acede9..0439e7a93 100644
--- a/pgml-sdks/pgml/javascript/README.md
+++ b/pgml-sdks/pgml/javascript/README.md
@@ -519,6 +519,24 @@ const pipeline = pgml.newPipeline("test_pipeline", model, splitter, {
 await collection.add_pipeline(pipeline)
 ```
 
+### Configuring HNSW Indexing Parameters
+
+Our SDK utilizes [pgvector](https://github.com/pgvector/pgvector) for storing vectors and performing recall. We use HNSW indexing because it offers the best balance of query speed and recall.
+
+Our SDK allows `m` (the maximum number of connections per layer, 16 by default) and `ef_construction` (the size of the dynamic candidate list used when constructing the graph, 64 by default) to be configured per pipeline.
+
+```javascript
+const model = pgml.newModel()
+const splitter = pgml.newSplitter()
+const pipeline = pgml.newPipeline("test_pipeline", model, splitter, {
+  hnsw: {
+    m: 100,
+    ef_construction: 200
+  }
+})
+await collection.add_pipeline(pipeline)
+```
+
 ### Searching with Pipelines
 
 Pipelines are a required argument when performing vector search. After a Pipeline has been added to a Collection, the Model and Splitter can be omitted when instantiating it.
diff --git a/pgml-sdks/pgml/python/README.md b/pgml-sdks/pgml/python/README.md
index a05c184ce..9eb69e4e8 100644
--- a/pgml-sdks/pgml/python/README.md
+++ b/pgml-sdks/pgml/python/README.md
@@ -530,6 +530,24 @@ pipeline = Pipeline("test_pipeline", model, splitter, {
 await collection.add_pipeline(pipeline)
 ```
 
+### Configuring HNSW Indexing Parameters
+
+Our SDK utilizes [pgvector](https://github.com/pgvector/pgvector) for storing vectors and performing recall. We use HNSW indexing because it offers the best balance of query speed and recall.
+
+Our SDK allows `m` (the maximum number of connections per layer, 16 by default) and `ef_construction` (the size of the dynamic candidate list used when constructing the graph, 64 by default) to be configured per pipeline.
+
+```python
+model = Model()
+splitter = Splitter()
+pipeline = Pipeline("test_pipeline", model, splitter, {
+    "hnsw": {
+        "m": 100,
+        "ef_construction": 200
+    }
+})
+await collection.add_pipeline(pipeline)
+```
+
 ### Searching with Pipelines
 
 Pipelines are a required argument when performing vector search. After a Pipeline has been added to a Collection, the Model and Splitter can be omitted when instantiating it.
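The `m` and `ef_construction` values documented above are index build parameters. pgvector also exposes a query-time knob, `hnsw.ef_search` (40 by default), which the `filter({"hnsw": {"ef_search": 2}})` tests added earlier in this series exercise. A plausible sketch of how such a setting can be scoped to a single recall query; the SDK's actual plumbing is not shown in this series:

```rust
use sqlx::{Executor, PgPool};

// Scope a custom ef_search to one transaction with SET LOCAL so only this
// recall query trades accuracy for speed. The value 2 mirrors the tests
// above; larger values improve recall at the cost of latency.
async fn recall_with_ef_search(pool: &PgPool) -> anyhow::Result<()> {
    let mut tx = pool.begin().await?;
    tx.execute("SET LOCAL hnsw.ef_search = 2").await?;
    // ... run the ORDER BY embedding <=> $1::vector LIMIT n query here ...
    tx.commit().await?;
    Ok(())
}
```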
From e078911d4a6f1bc06d5d2fcf4e24cd9d622d8bb8 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 6 Sep 2023 11:07:39 -0700 Subject: [PATCH 11/11] Removed unnecessary dependencies --- pgml-sdks/pgml/Cargo.lock | 12 ------------ pgml-sdks/pgml/Cargo.toml | 1 - 2 files changed, 13 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 2faa354f3..f68e47b68 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -50,17 +50,6 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" -[[package]] -name = "async-recursion" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e97ce7de6cf12de5d7226c73f5ba9811622f4db3a5b91b55c53e987e5f91cba" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.28", -] - [[package]] name = "async-trait" version = "0.1.71" @@ -1246,7 +1235,6 @@ name = "pgml" version = "0.9.1" dependencies = [ "anyhow", - "async-recursion", "async-trait", "chrono", "futures", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 7a0b23c5d..ca7782fd0 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -36,7 +36,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json"] } indicatif = "0.17.6" serde = "1.0.181" futures = "0.3.28" -async-recursion = "1.0.4" [features] default = []
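A note on the query rewrite in patches 06 and 07: pgvector can only serve a scan from a HNSW index when the query orders directly by a distance expression (`embedding <=> $1`) ascending with a `LIMIT`. The old CTE computed `1 - distance` and ordered by that descending, which forced a sequential scan. The rewritten queries therefore order by raw cosine distance and restore the familiar higher-is-better score in Rust afterward, as in this sketch:

```rust
use serde_json::Value as Json;

// Map pgvector cosine distance (0 = identical, larger = farther) back to the
// similarity score the SDK has always returned. Mirrors the `1. - score`
// post-processing added in patch 07.
fn to_similarity(rows: Vec<(f64, String, Json)>) -> Vec<(f64, String, Json)> {
    rows.into_iter()
        .map(|(distance, chunk, metadata)| (1. - distance, chunk, metadata))
        .collect()
}
```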
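The table-existence guard added to the 0.9.2 migration earlier in this series is worth reusing anywhere a pipeline may be registered but never synced. A standalone sketch of the same `information_schema` check, with the schema and table names supplied by the caller:

```rust
use sqlx::PgPool;

// Returns true if `schema.table` exists; the migration uses this to skip
// HNSW index creation for pipelines whose embeddings table was never created.
async fn table_exists(pool: &PgPool, schema: &str, table: &str) -> anyhow::Result<bool> {
    let exists: bool = sqlx::query_scalar(
        "SELECT EXISTS (SELECT * FROM information_schema.tables \
         WHERE table_name = $1 AND table_schema = $2)",
    )
    .bind(table)
    .bind(schema)
    .fetch_one(pool)
    .await?;
    Ok(exists)
}
```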
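Finally, a quick way to confirm the rewritten queries actually hit the new indices is to `EXPLAIN` a recall query and look for an index scan over the `_hnsw_vector_index`. A sketch with illustrative table and index names and a toy three-dimensional vector literal (a real embedding has the pipeline model's full dimension):

```rust
use sqlx::PgPool;

// Print the plan for a recall query; an "Index Scan using
// ..._hnsw_vector_index" line confirms pgvector is serving it from HNSW.
async fn explain_recall(pool: &PgPool) -> anyhow::Result<()> {
    let plan: Vec<String> = sqlx::query_scalar(
        "EXPLAIN SELECT chunk_id FROM my_collection.pipeline1_embeddings \
         ORDER BY embedding <=> '[0.1,0.2,0.3]'::vector LIMIT 10",
    )
    .fetch_all(pool)
    .await?;
    for line in plan {
        println!("{line}");
    }
    Ok(())
}
```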