From 28eff549af790730b302775d0272b9539c4e8e56 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 19 Jun 2024 07:32:08 -0700 Subject: [PATCH] Batch upsert documents --- pgml-sdks/pgml/Cargo.lock | 2 +- pgml-sdks/pgml/Cargo.toml | 2 +- pgml-sdks/pgml/src/batch.rs | 101 +++++++++++++++++++++++++++++++ pgml-sdks/pgml/src/collection.rs | 56 ++++++++++++----- pgml-sdks/pgml/src/lib.rs | 4 ++ 5 files changed, 147 insertions(+), 18 deletions(-) create mode 100644 pgml-sdks/pgml/src/batch.rs diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 784b528a7..3e595dfec 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1590,7 +1590,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pgml" -version = "1.0.4" +version = "1.2.0" dependencies = [ "anyhow", "async-trait", diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 0a190eaf4..74b2e1f62 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "1.1.0" +version = "1.2.0" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/src/batch.rs b/pgml-sdks/pgml/src/batch.rs new file mode 100644 index 000000000..8eccb5511 --- /dev/null +++ b/pgml-sdks/pgml/src/batch.rs @@ -0,0 +1,101 @@ +//! Upsert documents in batches. 
+ +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + +use tracing::instrument; + +use crate::{types::Json, Collection}; + +#[cfg(feature = "python")] +use crate::{collection::CollectionPython, types::JsonPython}; + +#[cfg(feature = "c")] +use crate::{collection::CollectionC, languages::c::JsonC}; + +/// A batch of documents staged for upsert +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] +pub struct Batch { + collection: Collection, + pub(crate) documents: Vec, + pub(crate) size: i64, + pub(crate) args: Option, +} + +#[cfg_attr(feature = "rust_bridge", alias_methods(new, upsert_documents, finish,))] +impl Batch { + /// Create a new upsert batch. + /// + /// # Arguments + /// + /// * `collection` - The collection to upsert documents to. + /// * `size` - The size of the batch. + /// * `args` - Optional arguments to pass to the upsert operation. + /// + /// # Example + /// + /// ``` + /// use pgml::{Collection, Batch}; + /// + /// let collection = Collection::new("my_collection"); + /// let batch = Batch::new(&collection, 100, None); + /// ``` + pub fn new(collection: &Collection, size: i64, args: Option) -> Self { + Self { + collection: collection.clone(), + args, + documents: Vec::new(), + size, + } + } + + /// Upsert documents into the collection. If the batch is full, save the documents. + /// + /// When using this method, remember to call [finish](Batch::finish) to save any remaining documents + /// in the last batch. + /// + /// # Arguments + /// + /// * `documents` - The documents to upsert. 
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use pgml::{Collection, Batch};
+    /// use serde_json::json;
+    ///
+    /// let collection = Collection::new("my_collection");
+    /// let mut batch = Batch::new(&collection, 100, None);
+    ///
+    /// batch.upsert_documents(vec![json!({"id": 1}), json!({"id": 2})]).await?;
+    /// batch.finish().await?;
+    /// ```
+    #[instrument(skip(self))]
+    pub async fn upsert_documents(&mut self, documents: Vec<Json>) -> anyhow::Result<()> {
+        for document in documents {
+            if self.documents.len() >= self.size as usize {
+                self.collection
+                    .upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
+                    .await?;
+                self.documents.clear();
+            }
+
+            self.documents.push(document);
+        }
+
+        Ok(())
+    }
+
+    /// Save any remaining documents in the last batch.
+    #[instrument(skip(self))]
+    pub async fn finish(&mut self) -> anyhow::Result<()> {
+        if !self.documents.is_empty() {
+            self.collection
+                .upsert_documents(std::mem::take(&mut self.documents), self.args.clone())
+                .await?;
+        }
+
+        Ok(())
+    }
+}
diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs
index b0a814b4f..71c11b1f7 100644
--- a/pgml-sdks/pgml/src/collection.rs
+++ b/pgml-sdks/pgml/src/collection.rs
@@ -208,7 +208,7 @@ impl Collection {
             .all(|c| c.is_alphanumeric() || c.is_whitespace() || c == '-' || c == '_')
         {
             anyhow::bail!(
-                "Name must only consist of letters, numebers, white space, and '-' or '_'"
+                "Collection name must only consist of letters, numbers, white space, and '-' or '_'"
             )
         }
         let (pipelines_table_name, documents_table_name) = Self::generate_table_names(name);
@@ -264,21 +264,43 @@ impl Collection {
         } else {
             let mut transaction = pool.begin().await?;
 
-            let project_id: i64 = sqlx::query_scalar("INSERT INTO pgml.projects (name, task) VALUES ($1, 'embedding'::pgml.task) ON CONFLICT (name) DO UPDATE SET task = EXCLUDED.task RETURNING id, task::TEXT")
-                .bind(&self.name)
-                .fetch_one(&mut *transaction)
-                .await?;
+            let project_id: i64 =
sqlx::query_scalar( + " + INSERT INTO pgml.projects ( + name, + task + ) VALUES ( + $1, + 'embedding'::pgml.task + ) ON CONFLICT (name) + DO UPDATE SET + task = EXCLUDED.task + RETURNING id, task::TEXT", + ) + .bind(&self.name) + .fetch_one(&mut *transaction) + .await?; transaction .execute(query_builder!("CREATE SCHEMA IF NOT EXISTS %s", self.name).as_str()) .await?; - let c: models::Collection = sqlx::query_as("INSERT INTO pgml.collections (name, project_id, sdk_version) VALUES ($1, $2, $3) ON CONFLICT (name) DO NOTHING RETURNING *") - .bind(&self.name) - .bind(project_id) - .bind(crate::SDK_VERSION) - .fetch_one(&mut *transaction) - .await?; + let c: models::Collection = sqlx::query_as( + " + INSERT INTO pgml.collections ( + name, + project_id, + sdk_version + ) VALUES ( + $1, $2, $3 + ) ON CONFLICT (name) DO NOTHING + RETURNING *", + ) + .bind(&self.name) + .bind(project_id) + .bind(crate::SDK_VERSION) + .fetch_one(&mut *transaction) + .await?; let collection_database_data = CollectionDatabaseData { id: c.id, @@ -353,23 +375,25 @@ impl Collection { .await?; if exists { - warn!("Pipeline {} already exists not adding", pipeline.name); + warn!("Pipeline {} already exists, not adding", pipeline.name); } else { - // We want to intentially throw an error if they have already added this pipeline + // We want to intentionally throw an error if they have already added this pipeline // as we don't want to casually resync + let mp = MultiProgress::new(); + mp.println(format!("Adding pipeline {}...", pipeline.name))?; + pipeline .verify_in_database(project_info, true, &pool) .await?; - let mp = MultiProgress::new(); - mp.println(format!("Added Pipeline {}, Now Syncing...", pipeline.name))?; + mp.println(format!("Added pipeline {}, now syncing...", pipeline.name))?; // TODO: Revisit this. 
If the pipeline is added but fails to sync, then it will be "out of sync" with the documents in the table // This is rare, but could happen pipeline .resync(project_info, pool.acquire().await?.as_mut()) .await?; - mp.println(format!("Done Syncing {}\n", pipeline.name))?; + mp.println(format!("Done syncing {}\n", pipeline.name))?; } Ok(()) } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 16ec25ece..0b09f43ea 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -14,6 +14,7 @@ use tokio::runtime::{Builder, Runtime}; use tracing::Level; use tracing_subscriber::FmtSubscriber; +mod batch; mod builtins; #[cfg(any(feature = "python", feature = "javascript"))] mod cli; @@ -40,6 +41,7 @@ mod utils; mod vector_search_query_builder; // Re-export +pub use batch::Batch; pub use builtins::Builtins; pub use collection::Collection; pub use model::Model; @@ -217,6 +219,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -275,6 +278,7 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { "newOpenSourceAI", open_source_ai::OpenSourceAIJavascript::new, )?; + cx.export_function("newBatch", batch::BatchJavascript::new)?; Ok(()) }