From 203951cbacf4a26fe106d741c32bf876f59517f5 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Tue, 14 May 2024 10:08:12 -0700
Subject: [PATCH] SDK - Allow parallel batch uploads

---
 pgml-sdks/pgml/src/collection.rs | 246 ++++++++++++++++++-------------
 pgml-sdks/pgml/src/lib.rs        |  80 ++++++++++
 2 files changed, 224 insertions(+), 102 deletions(-)

diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs
index d8bf9e854..27f95813f 100644
--- a/pgml-sdks/pgml/src/collection.rs
+++ b/pgml-sdks/pgml/src/collection.rs
@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;
 
 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,13 +498,16 @@ impl Collection {
         // -> Insert the document
         // -> Foreach pipeline check if we need to resync the document and if so sync the document
         // -> Commit the transaction
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;
 
         let pool = get_or_initialize_pool(&self.database_url).await?;
 
-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
@@ -510,14 +515,63 @@ impl Collection {
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();
 
-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;
 
         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");
 
+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            if set.len() < parallel_batches {
+                let local_self = self.clone();
+                let local_batch = batch.to_owned();
+                let local_args = args.clone();
+                let local_pipelines = pipelines.clone();
+                let local_pool = pool.clone();
+                set.spawn(async move {
+                    local_self
+                        ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                        .await
+                });
+            } else {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+        }
+
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };
 
-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;
 
-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
            );
+            binding_parameter_counter += 3;
+        }
 
-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );
 
-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;
 
-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
+                })
+                .map(|(document_id, _, _)| *document_id)
+                .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
             }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
 
         Ok(())
     }
diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs
index 87b99657c..b805fe38e 100644
--- a/pgml-sdks/pgml/src/lib.rs
+++ b/pgml-sdks/pgml/src/lib.rs
@@ -431,6 +431,86 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn can_add_pipeline_and_upsert_documents_with_parallel_batches() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let collection_name = "test_r_c_capaud_107";
+        let pipeline_name = "test_r_p_capaud_6";
+        let mut pipeline = Pipeline::new(
+            pipeline_name,
+            Some(
+                json!({
+                    "title": {
+                        "semantic_search": {
+                            "model": "intfloat/e5-small"
+                        }
+                    },
+                    "body": {
+                        "splitter": {
+                            "model": "recursive_character",
+                            "parameters": {
+                                "chunk_size": 1000,
+                                "chunk_overlap": 40
+                            }
+                        },
+                        "semantic_search": {
+                            "model": "hkunlp/instructor-base",
+                            "parameters": {
+                                "instruction": "Represent the Wikipedia document for retrieval"
+                            }
+                        },
+                        "full_text_search": {
+                            "configuration": "english"
+                        }
+                    }
+                })
+                .into(),
+            ),
+        )?;
+        let mut collection = Collection::new(collection_name, None)?;
+        collection.add_pipeline(&mut pipeline).await?;
+        let documents = generate_dummy_documents(20);
+        collection
+            .upsert_documents(
+                documents.clone(),
+                Some(
+                    json!({
+                        "batch_size": 4,
+                        "parallel_batches": 5
+                    })
+                    .into(),
+                ),
+            )
+            .await?;
+        let pool = get_or_initialize_pool(&None).await?;
+        let documents_table = format!("{}.documents", collection_name);
+        let queried_documents: Vec<models::Document> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", documents_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(queried_documents.len() == 20);
+        let chunks_table = format!("{}_{}.title_chunks", collection_name, pipeline_name);
+        let title_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(title_chunks.len() == 20);
+        let chunks_table = format!("{}_{}.body_chunks", collection_name, pipeline_name);
+        let body_chunks: Vec<models::Chunk> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", chunks_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(body_chunks.len() == 120);
+        let tsvectors_table = format!("{}_{}.body_tsvectors", collection_name, pipeline_name);
+        let tsvectors: Vec<models::TSVector> =
+            sqlx::query_as(&query_builder!("SELECT * FROM %s", tsvectors_table))
+                .fetch_all(&pool)
+                .await?;
+        assert!(tsvectors.len() == 120);
+        collection.archive().await?;
+        Ok(())
+    }
+
     #[tokio::test]
     async fn can_upsert_documents_and_add_pipeline() -> anyhow::Result<()> {
         internal_init_logger(None, None).ok();