From df0ecd94acc9a18d67a6c86bd9a1823a6776f143 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:43:56 -0700 Subject: [PATCH 1/2] Gated rust_bridge behind a feature flag --- pgml-sdks/pgml/Cargo.lock | 3 ++ pgml-sdks/pgml/Cargo.toml | 10 ++-- pgml-sdks/pgml/src/builtins.rs | 22 +++++--- pgml-sdks/pgml/src/collection.rs | 63 ++++++++++++---------- pgml-sdks/pgml/src/model.rs | 9 ++-- pgml-sdks/pgml/src/open_source_ai.rs | 22 +++++--- pgml-sdks/pgml/src/pipeline.rs | 9 ++-- pgml-sdks/pgml/src/query_builder.rs | 12 +++-- pgml-sdks/pgml/src/query_runner.rs | 26 +++++---- pgml-sdks/pgml/src/splitter.rs | 9 ++-- pgml-sdks/pgml/src/transformer_pipeline.rs | 12 +++-- pgml-sdks/pgml/src/types.rs | 11 ++-- 12 files changed, 132 insertions(+), 76 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.lock b/pgml-sdks/pgml/Cargo.lock index 784b528a7..2f600f25b 100644 --- a/pgml-sdks/pgml/Cargo.lock +++ b/pgml-sdks/pgml/Cargo.lock @@ -1963,6 +1963,7 @@ dependencies = [ [[package]] name = "rust_bridge" version = "0.1.0" +source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666" dependencies = [ "rust_bridge_macros", "rust_bridge_traits", @@ -1971,6 +1972,7 @@ dependencies = [ [[package]] name = "rust_bridge_macros" version = "0.1.0" +source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666" dependencies = [ "anyhow", "proc-macro2", @@ -1981,6 +1983,7 @@ dependencies = [ [[package]] name = "rust_bridge_traits" version = "0.1.0" +source = "git+https://github.com/postgresml/postgresml#b949d45a2353141b3635d7f88b1fdd9cf78fa666" dependencies = [ "neon", ] diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 008e5d838..b1f5044c8 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -14,7 +14,8 @@ name = "pgml" crate-type = ["lib", "cdylib"] [dependencies] -rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0"} +# rust_bridge = {path = "../rust-bridge/rust-bridge", version = "0.1.0", optional = true } +rust_bridge = {git = "https://github.com/postgresml/postgresml", version = "0.1.0", optional = true } sqlx = { version = "0.7.3", features = [ "runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } serde_json = "1.0.9" anyhow = "1.0.9" @@ -50,6 +51,7 @@ serde_with = "3.8.1" [features] default = [] -python = ["dep:pyo3", "dep:pyo3-asyncio"] -javascript = ["dep:neon"] -c = [] +rust_bridge = ["dep:rust_bridge"] +python = ["rust_bridge", "dep:pyo3", "dep:pyo3-asyncio"] +javascript = ["rust_bridge", "dep:neon"] +c = ["rust_bridge"] diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index 0a1d60b71..f8e913f2c 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -1,23 +1,29 @@ use anyhow::Context; -use rust_bridge::{alias, alias_methods}; use sqlx::Row; use tracing::instrument; -/// Provides access to builtin database methods -#[derive(alias, Debug, Clone)] -pub struct Builtins { - database_url: Option, -} - use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + #[cfg(feature = "python")] use crate::{query_runner::QueryRunnerPython, types::JsonPython}; #[cfg(feature = "c")] use crate::{languages::c::JsonC, query_runner::QueryRunnerC}; -#[alias_methods(new, query, transform, embed, embed_batch)] +/// Provides access to builtin database methods +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] +pub struct Builtins { + database_url: Option, +} + +#[cfg_attr( + feature = "rust_bridge", + alias_methods(new, query, transform, embed, embed_batch) +)] impl Builtins { pub fn new(database_url: Option) -> Self { Self { database_url } diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index 957f15f02..f2f99ab98 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -2,7 +2,6 @@ use anyhow::Context; use indicatif::MultiProgress; use itertools::Itertools; use regex::Regex; -use rust_bridge::{alias, alias_methods}; use sea_query::Alias; use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; @@ -35,6 +34,12 @@ use crate::{ utils, }; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + +#[cfg(feature = "c")] +use crate::languages::c::GeneralJsonAsyncIteratorC; + #[cfg(feature = "python")] use crate::{ pipeline::PipelinePython, @@ -43,7 +48,7 @@ use crate::{ }; /// A RAGStream Struct -#[derive(alias)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] #[allow(dead_code)] pub struct RAGStream { general_json_async_iterator: Option, @@ -57,7 +62,7 @@ impl Clone for RAGStream { } } -#[alias_methods(stream, sources)] +#[cfg_attr(feature = "rust_bridge", alias_methods(stream, sources))] impl RAGStream { pub fn stream(&mut self) -> anyhow::Result { self.general_json_async_iterator @@ -140,7 +145,8 @@ pub(crate) struct CollectionDatabaseData { } /// A collection of documents -#[derive(alias, Debug, Clone)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] pub struct Collection { pub(crate) name: String, pub(crate) database_url: Option, @@ -149,29 +155,32 @@ pub struct Collection { pub(crate) database_data: Option, } -#[alias_methods( - new, - upsert_documents, - get_documents, - delete_documents, - get_pipelines, - get_pipeline, - add_pipeline, - remove_pipeline, - enable_pipeline, - disable_pipeline, - search, - add_search_event, - vector_search, - query, - rag, - rag_stream, - exists, - archive, - upsert_directory, - upsert_file, - generate_er_diagram, - get_pipeline_status +#[cfg_attr( + feature = "rust_bridge", + alias_methods( + new, + upsert_documents, + get_documents, + delete_documents, + get_pipelines, + get_pipeline, + add_pipeline, + remove_pipeline, + enable_pipeline, + disable_pipeline, + search, + add_search_event, + vector_search, + query, + rag, + rag_stream, + exists, + archive, + upsert_directory, + upsert_file, + generate_er_diagram, + get_pipeline_status + ) )] impl Collection { /// Creates a new [Collection] diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index 539c6d1cd..81079400f 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -1,4 +1,3 @@ -use rust_bridge::{alias, alias_methods}; use sqlx::{Pool, Postgres}; use tracing::instrument; @@ -14,6 +13,9 @@ use crate::types::JsonPython; #[cfg(feature = "c")] use crate::languages::c::JsonC; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + /// A few notes on the following enums: /// - Sqlx does provide type derivation for enums, but it's not very good /// - Queries using these enums require a number of additional queries to get their oids and @@ -55,7 +57,8 @@ pub(crate) struct ModelDatabaseData { } /// A model used for embedding, inference, etc... -#[derive(alias, Debug, Clone)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] pub struct Model { pub(crate) name: String, pub(crate) runtime: ModelRuntime, @@ -69,7 +72,7 @@ impl Default for Model { } } -#[alias_methods(new, transform)] +#[cfg_attr(feature = "rust_bridge", alias_methods(new, transform))] impl Model { /// Creates a new [Model] pub fn new(name: Option, source: Option, parameters: Option) -> Self { diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index d7810327d..f582ee80d 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -1,6 +1,5 @@ use anyhow::Context; use futures::{Stream, StreamExt}; -use rust_bridge::{alias, alias_methods}; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; @@ -10,6 +9,9 @@ use crate::{ TransformerPipeline, }; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + #[cfg(feature = "python")] use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython}; @@ -20,7 +22,8 @@ use crate::{ }; /// A drop in replacement for OpenAI -#[derive(alias, Debug, Clone)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] pub struct OpenSourceAI { database_url: Option, } @@ -166,12 +169,15 @@ impl Iterator for AsyncToSyncJsonIterator { } } -#[alias_methods( - new, - chat_completions_create, - chat_completions_create_async, - chat_completions_create_stream, - chat_completions_create_stream_async +#[cfg_attr( + feature = "rust_bridge", + alias_methods( + new, + chat_completions_create, + chat_completions_create_async, + chat_completions_create_stream, + chat_completions_create_stream_async + ) )] impl OpenSourceAI { /// Creates a new [OpenSourceAI] diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index e082e9e4b..33e552cc7 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -1,5 +1,4 @@ use anyhow::Context; -use rust_bridge::{alias, alias_methods}; use serde::Deserialize; use serde_json::json; use sqlx::{Executor, PgConnection, Pool, Postgres, Transaction}; @@ -16,6 +15,9 @@ use crate::{ types::{DateTime, Json, TryToNumeric}, }; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + #[cfg(feature = "python")] use crate::types::JsonPython; @@ -179,7 +181,8 @@ pub struct PipelineDatabaseData { } /// A pipeline that describes transformations to documents -#[derive(alias, Debug, Clone)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] pub struct Pipeline { pub(crate) name: String, pub(crate) schema: Option, @@ -205,7 +208,7 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { }) } -#[alias_methods(new)] +#[cfg_attr(feature = "rust_bridge", alias_methods(new))] impl Pipeline { /// Creates a [Pipeline] /// diff --git a/pgml-sdks/pgml/src/query_builder.rs b/pgml-sdks/pgml/src/query_builder.rs index 46b4255b0..4e3b9babf 100644 --- a/pgml-sdks/pgml/src/query_builder.rs +++ b/pgml-sdks/pgml/src/query_builder.rs @@ -3,26 +3,32 @@ // No new things should be added here, instead add new items to collection.vector_search use anyhow::Context; -use rust_bridge::{alias, alias_methods}; use serde_json::json; use tracing::instrument; use crate::{pipeline::Pipeline, types::Json, Collection}; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + #[cfg(feature = "python")] use crate::{pipeline::PipelinePython, types::JsonPython}; #[cfg(feature = "c")] use crate::{languages::c::JsonC, pipeline::PipelineC}; -#[derive(alias, Clone, Debug)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Clone, Debug)] pub struct QueryBuilder { collection: Collection, query: Json, pipeline: Option, } -#[alias_methods(limit, filter, vector_recall, to_full_string, fetch_all(skip = "C"))] +#[cfg_attr( + feature = "rust_bridge", + alias_methods(limit, filter, vector_recall, to_full_string, fetch_all(skip = "C")) +)] impl QueryBuilder { pub fn new(collection: Collection) -> Self { let query = json!({ diff --git a/pgml-sdks/pgml/src/query_runner.rs b/pgml-sdks/pgml/src/query_runner.rs index cb5ba77cd..0e3ad396c 100644 --- a/pgml-sdks/pgml/src/query_runner.rs +++ b/pgml-sdks/pgml/src/query_runner.rs @@ -1,4 +1,3 @@ -use rust_bridge::{alias, alias_methods}; use sqlx::postgres::PgArguments; use sqlx::query::Query; use sqlx::{Postgres, Row}; @@ -11,6 +10,9 @@ use crate::types::JsonPython; #[cfg(feature = "c")] use crate::languages::c::JsonC; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + #[derive(Clone, Debug)] enum BindValue { String(String), @@ -20,21 +22,25 @@ enum BindValue { Json(Json), } -#[derive(alias, Clone, Debug)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Clone, Debug)] pub struct QueryRunner { query: String, bind_values: Vec, database_url: Option, } -#[alias_methods( - fetch_all, - execute, - bind_string, - bind_int, - bind_float, - bind_bool, - bind_json +#[cfg_attr( + feature = "rust_bridge", + alias_methods( + fetch_all, + execute, + bind_string, + bind_int, + bind_float, + bind_bool, + bind_json + ) )] impl QueryRunner { pub fn new(query: &str, database_url: Option) -> Self { diff --git a/pgml-sdks/pgml/src/splitter.rs b/pgml-sdks/pgml/src/splitter.rs index b7dd6c74d..f82d13803 100644 --- a/pgml-sdks/pgml/src/splitter.rs +++ b/pgml-sdks/pgml/src/splitter.rs @@ -1,4 +1,3 @@ -use rust_bridge::{alias, alias_methods}; use sqlx::{postgres::PgConnection, Pool, Postgres}; use tracing::instrument; @@ -14,6 +13,9 @@ use crate::types::JsonPython; #[cfg(feature = "c")] use crate::languages::c::JsonC; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + #[allow(dead_code)] #[derive(Debug, Clone)] pub(crate) struct SplitterDatabaseData { @@ -22,7 +24,8 @@ pub(crate) struct SplitterDatabaseData { } /// A text splitter -#[derive(alias, Debug, Clone)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] pub struct Splitter { pub(crate) name: String, pub(crate) parameters: Json, @@ -35,7 +38,7 @@ impl Default for Splitter { } } -#[alias_methods(new)] +#[cfg_attr(feature = "rust_bridge", alias_methods(new))] impl Splitter { /// Creates a new [Splitter] /// diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 18160198d..bb44e591a 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -1,10 +1,13 @@ use anyhow::Context; -use rust_bridge::{alias, alias_methods}; use sqlx::Row; use tracing::instrument; +#[cfg(feature = "rust_bridge")] +use rust_bridge::{alias, alias_methods}; + /// Provides access to builtin database methods -#[derive(alias, Debug, Clone)] +#[cfg_attr(feature = "rust_bridge", derive(alias))] +#[derive(Debug, Clone)] pub struct TransformerPipeline { task: Json, database_url: Option, @@ -19,7 +22,10 @@ use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython}; #[cfg(feature = "c")] use crate::{languages::c::GeneralJsonAsyncIteratorC, languages::c::JsonC}; -#[alias_methods(new, transform, transform_stream)] +#[cfg_attr( + feature = "rust_bridge", + alias_methods(new, transform, transform_stream) +)] impl TransformerPipeline { /// Creates a new [TransformerPipeline] /// diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 2d47de710..4b57f0227 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -1,12 +1,14 @@ use anyhow::Context; use futures::{stream::BoxStream, Stream, StreamExt}; use itertools::Itertools; -use rust_bridge::alias_manual; use sea_query::Iden; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::ops::{Deref, DerefMut}; +#[cfg(feature = "rust_bridge")] +use rust_bridge::alias_manual; + #[derive(Serialize, Deserialize)] pub struct CustomU64Convertor(pub Value); @@ -31,7 +33,8 @@ impl From for u64 { } /// A wrapper around `serde_json::Value` -#[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] +#[cfg_attr(feature = "rust_bridge", derive(alias_manual))] +#[derive(sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] #[sqlx(transparent)] pub struct Json(pub serde_json::Value); @@ -150,7 +153,7 @@ impl IntoTableNameAndSchema for String { } /// A wrapper around `BoxStream<'static, anyhow::Result>` -#[derive(alias_manual)] +#[cfg_attr(feature = "rust_bridge", derive(alias_manual))] pub struct GeneralJsonAsyncIterator(pub BoxStream<'static, anyhow::Result>); impl Stream for GeneralJsonAsyncIterator { @@ -165,7 +168,7 @@ impl Stream for GeneralJsonAsyncIterator { } /// A wrapper around `Box> + Send>` -#[derive(alias_manual)] +#[cfg_attr(feature = "rust_bridge", derive(alias_manual))] pub struct GeneralJsonIterator(pub Box> + Send>); impl Iterator for GeneralJsonIterator { From 7379a0e1e5480407d82614b799b7c4096ae21ac5 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:54:29 -0700 Subject: [PATCH 2/2] Add an example for rag --- pgml-sdks/pgml/src/collection.rs | 60 ++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index f2f99ab98..1cd6ccd8c 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -1137,6 +1137,65 @@ impl Collection { .collect()) } + /// Performs rag on the [Collection] + /// + /// # Arguments + /// * `query` - The query to search for + /// * `pipeline` - The [Pipeline] to use for the search + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// let results = collection.rag(json!({ + /// "CONTEXT": { + /// "vector_search": { + /// "query": { + /// "fields": { + /// "body": { + /// "query": "Test document: 2", + /// "parameters": { + /// "prompt": "query: " + /// } + /// }, + /// }, + /// }, + /// "document": { + /// "keys": [ + /// "id" + /// ] + /// }, + /// "limit": 2 + /// }, + /// "aggregate": { + /// "join": "\n" + /// } + /// }, + /// "CUSTOM": { + /// "sql": "SELECT 'test'" + /// }, + /// "chat": { + /// "model": "meta-llama/Meta-Llama-3-8B-Instruct", + /// "messages": [ + /// { + /// "role": "system", + /// "content": "You are a friendly and helpful chatbot" + /// }, + /// { + /// "role": "user", + /// "content": "Some text with {CONTEXT} - {CUSTOM}", + /// } + /// ], + /// "max_tokens": 10 + /// } + /// }).into(), &mut pipeline).await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn rag(&self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -1147,6 +1206,7 @@ impl Collection { Ok(std::mem::take(&mut results[0].0)) } + /// Same as rag buit returns a stream of results #[instrument(skip(self))] pub async fn rag_stream( &self,