From 51d005775fad372bfb27d22308ed3e4754b22579 Mon Sep 17 00:00:00 2001
From: Lev Kokotov
Date: Wed, 22 May 2024 08:59:47 -0700
Subject: [PATCH 1/3] Add pgml.embed() to the builtins

---
 pgml-sdks/pgml/Cargo.lock      |  2 +-
 pgml-sdks/pgml/Cargo.toml      |  2 +-
 pgml-sdks/pgml/src/builtins.rs | 18 +++++++++++++++++-
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock
index 74f0c7825..fdb5066eb 100644
--- a/pgml-sdks/pgml/Cargo.lock
+++ b/pgml-sdks/pgml/Cargo.lock
@@ -1592,7 +1592,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
 [[package]]
 name = "pgml"
-version = "1.0.2"
+version = "1.0.4"
 dependencies = [
  "anyhow",
  "async-trait",
diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml
index dce50c859..7837e62fb 100644
--- a/pgml-sdks/pgml/Cargo.toml
+++ b/pgml-sdks/pgml/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "pgml"
-version = "1.0.3"
+version = "1.0.4"
 edition = "2021"
 authors = ["PosgresML "]
 homepage = "https://postgresml.org/"
diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs
index 6a4200457..62b985aa8 100644
--- a/pgml-sdks/pgml/src/builtins.rs
+++ b/pgml-sdks/pgml/src/builtins.rs
@@ -13,7 +13,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
 #[cfg(feature = "python")]
 use crate::{query_runner::QueryRunnerPython, types::JsonPython};
 
-#[alias_methods(new, query, transform)]
+#[alias_methods(new, query, transform, embed)]
 impl Builtins {
     pub fn new(database_url: Option<String>) -> Self {
         Self { database_url }
@@ -87,6 +87,22 @@ impl Builtins {
         let results = results.first().unwrap().get::<serde_json::Value, _>(0);
         Ok(Json(results))
     }
+
+    /// Run the built-in `pgml.embed()` function.
+    ///
+    /// # Arguments
+    ///
+    /// * `model` - The model to use.
+    /// * `text` - The text to embed.
+    ///
+    pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result<Json> {
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let query = sqlx::query("SELECT pgml.embed($1, $2)");
+        let result = query.bind(model).bind(text).fetch_one(&pool).await?;
+        let result = result.get::<Vec<f32>, _>(0);
+        let result = serde_json::to_value(result)?;
+        Ok(Json(result))
+    }
 }
 
 #[cfg(test)]

From 8b0b9ac46375f39a399ea056481284d71bc62128 Mon Sep 17 00:00:00 2001
From: Lev Kokotov
Date: Wed, 22 May 2024 09:16:07 -0700
Subject: [PATCH 2/3] embed batch

---
Review note: embed_batch now returns the "only text embeddings are
supported" error for non-string elements instead of panicking through
unwrap(). The commit and blob ids in this series were not regenerated
after that edit.

 pgml-sdks/pgml/src/builtins.rs | 37 +++++++++++++++++++++++++++++++++++--
 1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs
index 62b985aa8..15418fd46 100644
--- a/pgml-sdks/pgml/src/builtins.rs
+++ b/pgml-sdks/pgml/src/builtins.rs
@@ -1,3 +1,4 @@
+use anyhow::Context;
 use rust_bridge::{alias, alias_methods};
 use sqlx::Row;
 use tracing::instrument;
@@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
 #[cfg(feature = "python")]
 use crate::{query_runner::QueryRunnerPython, types::JsonPython};
 
-#[alias_methods(new, query, transform, embed)]
+#[alias_methods(new, query, transform, embed, embed_batch)]
 impl Builtins {
     pub fn new(database_url: Option<String>) -> Self {
         Self { database_url }
@@ -97,12 +98,44 @@ impl Builtins {
     ///
     pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result<Json> {
         let pool = get_or_initialize_pool(&self.database_url).await?;
-        let query = sqlx::query("SELECT pgml.embed($1, $2)");
+        let query = sqlx::query("SELECT embed FROM pgml.embed($1, $2)");
         let result = query.bind(model).bind(text).fetch_one(&pool).await?;
         let result = result.get::<Vec<f32>, _>(0);
         let result = serde_json::to_value(result)?;
         Ok(Json(result))
     }
+
+    /// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
+    ///
+    /// # Arguments
+    ///
+    /// * `model` - The model to use.
+    /// * `texts` - The texts to embed, as a JSON array of strings; anything else is an error.
+    ///
+    pub async fn embed_batch(&self, model: &str, texts: Json) -> anyhow::Result<Json> {
+        let texts = texts
+            .0
+            .as_array()
+            .with_context(|| "embed_batch takes an array of texts")?
+            .into_iter()
+            .map(|v| {
+                v.as_str()
+                    .with_context(|| "only text embeddings are supported")
+                    .map(str::to_string)
+            })
+            .collect::<anyhow::Result<Vec<String>>>()?;
+        let pool = get_or_initialize_pool(&self.database_url).await?;
+        let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)");
+        let results = query
+            .bind(model)
+            .bind(texts)
+            .fetch_all(&pool)
+            .await?
+            .into_iter()
+            .map(|embeddings| embeddings.get::<Vec<f32>, _>(0))
+            .collect::<Vec<Vec<f32>>>();
+        Ok(Json(serde_json::to_value(results)?))
+    }
 }
 
 #[cfg(test)]

From e75db142464d264e102ac06068c10e14b4fd0a5e Mon Sep 17 00:00:00 2001
From: Lev Kokotov
Date: Wed, 22 May 2024 10:12:44 -0700
Subject: [PATCH 3/3] Tests

---
 pgml-sdks/pgml/python/tests/requirements.txt |  2 ++
 pgml-sdks/pgml/python/tests/test.py          | 12 +++++++++
 pgml-sdks/pgml/src/builtins.rs               | 26 +++++++++++++++++++-
 3 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 pgml-sdks/pgml/python/tests/requirements.txt

diff --git a/pgml-sdks/pgml/python/tests/requirements.txt b/pgml-sdks/pgml/python/tests/requirements.txt
new file mode 100644
index 000000000..ee4ba0186
--- /dev/null
+++ b/pgml-sdks/pgml/python/tests/requirements.txt
@@ -0,0 +1,2 @@
+pytest
+pytest-asyncio
diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py
index 87adf5ba7..b7367103a 100644
--- a/pgml-sdks/pgml/python/tests/test.py
+++ b/pgml-sdks/pgml/python/tests/test.py
@@ -72,6 +72,18 @@ def test_can_create_builtins():
     builtins = pgml.Builtins()
     assert builtins is not None
 
+@pytest.mark.asyncio
+async def test_can_embed_with_builtins():
+    builtins = pgml.Builtins()
+    result = await builtins.embed("intfloat/e5-small-v2", "test")
+    assert result is not None
+
+@pytest.mark.asyncio
+async def test_can_embed_batch_with_builtins():
+    builtins = pgml.Builtins()
+    result = await builtins.embed_batch("intfloat/e5-small-v2", ["test"])
+    assert result is not None
+
 
 ###################################################
 ## Test searches ##################################
diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs
index 15418fd46..531ae4fa3 100644
--- a/pgml-sdks/pgml/src/builtins.rs
+++ b/pgml-sdks/pgml/src/builtins.rs
@@ -116,7 +116,7 @@ impl Builtins {
         let texts = texts
             .0
             .as_array()
-            .with_context(|| "embed_batch takes an array of texts")?
+            .with_context(|| "embed_batch takes an array of strings")?
             .into_iter()
             .map(|v| {
                 v.as_str()
@@ -166,4 +166,28 @@ mod tests {
         assert!(results.as_array().is_some());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn can_embed() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let builtins = Builtins::new(None);
+        let results = builtins.embed("intfloat/e5-small-v2", "test").await?;
+        assert!(results.as_array().is_some());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn can_embed_batch() -> anyhow::Result<()> {
+        internal_init_logger(None, None).ok();
+        let builtins = Builtins::new(None);
+        let results = builtins
+            .embed_batch(
+                "intfloat/e5-small-v2",
+                Json(serde_json::json!(["test", "test2",])),
+            )
+            .await?;
+        assert!(results.as_array().is_some());
+        assert_eq!(results.as_array().unwrap().len(), 2);
+        Ok(())
+    }
 }