diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index f8107d050..d8bf9e854 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -3,6 +3,7 @@ use indicatif::MultiProgress; use itertools::Itertools; use regex::Regex; use rust_bridge::{alias, alias_methods}; +use sea_query::Alias; use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query}; use sea_query_binder::SqlxBinder; use serde_json::json; @@ -656,8 +657,9 @@ impl Collection { /// Each object must have a `field` key with the name of the field to order by, and a `direction` /// key with the value `asc` or `desc`. /// * `last_row_id` - The id of the last document returned - /// * `offset` - The number of documents to skip before returning results. - /// * `filter` - A JSON object specifying the filter to apply to the documents. + /// * `offset` - The number of documents to skip before returning results + /// * `filter` - A JSON object specifying the filter to apply to the documents + /// * `keys` - a JSON array specifying the document keys to return /// /// # Example /// @@ -691,9 +693,33 @@ impl Collection { self.documents_table_name.to_table_tuple(), SIden::Str("documents"), ) - .expr(Expr::cust("*")) // Adds the * in SELECT * FROM + .columns([ + SIden::Str("id"), + SIden::Str("created_at"), + SIden::Str("source_uuid"), + SIden::Str("version"), + ]) .limit(limit); + if let Some(keys) = args.remove("keys") { + let document_queries = keys + .as_array() + .context("`keys` must be an array")? + .iter() + .map(|d| { + let key = d.as_str().context("`key` value must be a string")?; + anyhow::Ok(format!("'{key}', document #> '{{{key}}}'")) + }) + .collect::<anyhow::Result<Vec<String>>>()?
+ .join(","); + query.expr_as( + Expr::cust(format!("jsonb_build_object({document_queries})")), + Alias::new("document"), + ); + } else { + query.column(SIden::Str("document")); + } + if let Some(order_by) = args.remove("order_by") { let order_by_builder = order_by_builder::OrderByBuilder::new(order_by, "documents", "document").build()?; diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index a7161d3c5..c32a4e5bb 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -1343,6 +1343,50 @@ mod tests { Ok(()) } + #[tokio::test] + async fn can_get_document_keys_get_documents() -> anyhow::Result<()> { + internal_init_logger(None, None).ok(); + let mut collection = Collection::new("test r_c_cuafgd_1", None)?; + + let documents = vec![ + serde_json::json!({"id": 1, "random_key": 10, "nested": {"nested2": "test" } , "text": "hello world 1"}).into(), + serde_json::json!({"id": 2, "random_key": 11, "text": "hello world 2"}).into(), + serde_json::json!({"id": 3, "random_key": 12, "text": "hello world 3"}).into(), + ]; + collection.upsert_documents(documents.clone(), None).await?; + + let documents = collection + .get_documents(Some( + serde_json::json!({ + "keys": [ + "id", + "random_key", + "nested,nested2" + ] + }) + .into(), + )) + .await?; + assert!(!documents[0]["document"] + .as_object() + .unwrap() + .contains_key("text")); + assert!(documents[0]["document"] + .as_object() + .unwrap() + .contains_key("id")); + assert!(documents[0]["document"] + .as_object() + .unwrap() + .contains_key("random_key")); + assert!(documents[0]["document"] + .as_object() + .unwrap() + .contains_key("nested,nested2")); + collection.archive().await?; + Ok(()) + } + #[tokio::test] async fn can_paginate_get_documents() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); diff --git a/pgml-sdks/pgml/src/queries.rs b/pgml-sdks/pgml/src/queries.rs index 1ea7001bf..6c0f788dd 100644 --- a/pgml-sdks/pgml/src/queries.rs +++ b/pgml-sdks/pgml/src/queries.rs 
@@ -30,8 +30,8 @@ CREATE TABLE IF NOT EXISTS %s ( id serial8 PRIMARY KEY, created_at timestamp NOT NULL DEFAULT now(), source_uuid uuid NOT NULL, - document jsonb NOT NULL, version jsonb NOT NULL DEFAULT '{}'::jsonb, + document jsonb NOT NULL, UNIQUE (source_uuid) ); "#;