From 8bc7a30bf1e9a36f471831d186710c6cb947ccf7 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 15 May 2024 10:37:55 -0700 Subject: [PATCH] SDK - Patch parallel uploads --- pgml-sdks/pgml/src/collection.rs | 30 ++++++++++++++---------------- pgml-sdks/pgml/src/lib.rs | 11 +++++++---- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 27f95813f..2f1291e82 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -533,23 +533,21 @@ impl Collection { let mut set = JoinSet::new(); for batch in documents.chunks(batch_size as usize) { - if set.len() < parallel_batches { - let local_self = self.clone(); - let local_batch = batch.to_owned(); - let local_args = args.clone(); - let local_pipelines = pipelines.clone(); - let local_pool = pool.clone(); - set.spawn(async move { - local_self - ._upsert_documents(local_batch, local_args, local_pipelines, local_pool) - .await - }); - } else { - if let Some(res) = set.join_next().await { - res??; - progress_bar.inc(batch_size); - } + if set.len() >= parallel_batches { + set.join_next().await.unwrap()??; + progress_bar.inc(batch_size); } + + let local_self = self.clone(); + let local_batch = batch.to_owned(); + let local_args = args.clone(); + let local_pipelines = pipelines.clone(); + let local_pool = pool.clone(); + set.spawn(async move { + local_self + ._upsert_documents(local_batch, local_args, local_pipelines, local_pool) + .await + }); } while let Some(res) = set.join_next().await { diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index b805fe38e..ddfc37341 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -442,7 +442,10 @@ mod tests { json!({ "title": { "semantic_search": { - "model": "intfloat/e5-small" + "model": "intfloat/e5-small-v2", + "parameters": { + "prompt": "passage: " + } } }, "body": { @@ -454,9 +457,9 @@ mod tests { } }, "semantic_search": { - "model": "hkunlp/instructor-base", + "model": "intfloat/e5-small-v2", "parameters": { - "instruction": "Represent the Wikipedia document for retrieval" + "prompt": "passage: " } }, "full_text_search": { @@ -475,7 +478,7 @@ mod tests { documents.clone(), Some( json!({ - "batch_size": 4, + "batch_size": 2, "parallel_batches": 5 }) .into(),