diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 17d8b8a3a..6c0f75838 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1746,7 +1746,7 @@ dependencies = [ [[package]] name = "pgml" -version = "2.8.5" +version = "2.9.0" dependencies = [ "anyhow", "blas", @@ -1934,6 +1934,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "postgres" version = "0.19.7" @@ -2030,15 +2036,17 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -2047,9 +2055,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" dependencies = [ "once_cell", "target-lexicon", @@ -2057,9 +2065,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" +checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" dependencies = [ "libc", "pyo3-build-config", @@ -2067,9 +2075,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2079,12 +2087,13 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote 1.0.35", "syn 2.0.46", ] diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index c596e2d53..7787eb25c 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "2.8.5" +version = "2.9.0" edition = "2021" [lib] @@ -41,7 +41,7 @@ ndarray-stats = "0.5.1" parking_lot = "0.12" pgrx = "=0.11.3" pgrx-pg-sys = "=0.11.3" -pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.20.0", features = ["anyhow", "auto-initialize"], optional = true } rand = "0.8" rmp-serde = { version = "1.1" } signal-hook = "0.3" diff --git a/pgml-extension/sql/pgml--2.8.5--2.9.0.sql b/pgml-extension/sql/pgml--2.8.5--2.9.0.sql new file mode 100644 index 000000000..a5e152040 --- /dev/null +++ b/pgml-extension/sql/pgml--2.8.5--2.9.0.sql @@ -0,0 +1,15 @@ +-- src/api.rs:613 +-- pgml::api::rank +CREATE FUNCTION pgml."rank"( + "transformer" TEXT, /* &str */ + "query" TEXT, /* &str */ + "documents" TEXT[], /* alloc::vec::Vec<&str> */ + "kwargs" jsonb DEFAULT '{}' /* pgrx::datum::json::JsonB */ +) RETURNS TABLE ( + "corpus_id" bigint, /* i64 */ + "score" double precision, /* f64 */ + "text" TEXT /* core::option::Option */ +) +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'rank_wrapper'; diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 14efde32b..923c6fc70 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -603,7 +603,21 @@ pub fn embed_batch( kwargs: default!(JsonB, "'{}'"), ) -> SetOfIterator<'static, Vec> { match crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) { - Ok(output) => SetOfIterator::new(output.into_iter()), + Ok(output) => SetOfIterator::new(output), + Err(e) => error!("{e}"), + } +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "rank")] +pub fn rank( + transformer: &str, + query: &str, + documents: Vec<&str>, + kwargs: default!(JsonB, "'{}'"), +) -> TableIterator<'static, (name!(corpus_id, i64), name!(score, f64), name!(text, Option))> { + match crate::bindings::transformers::rank(transformer, query, documents, &kwargs.0) { + Ok(output) => TableIterator::new(output.into_iter().map(|x| (x.corpus_id, x.score, x.text))), Err(e) => error!("{e}"), } } diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 9b4f51b9f..33f103e62 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -6,7 +6,8 @@ use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, bail, Context, Result}; use pgrx::*; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple}; +use serde::Deserialize; use serde_json::Value; use crate::create_pymodule; @@ -21,6 +22,59 @@ pub use transform::*; create_pymodule!("/src/bindings/transformers/transformers.py"); +// Need a wrapper so we can implement traits for it +struct Json(Value); + +impl From for Value { + fn from(value: Json) -> Self { + value.0 + } +} + +impl FromPyObject<'_> for Json { + fn extract(ob: &PyAny) -> PyResult { + if ob.is_instance_of::() { + let dict: &PyDict = ob.downcast()?; + let mut json = serde_json::Map::new(); + for (key, value) in dict.iter() { + let value = Json::extract(value)?; + json.insert(String::extract(key)?, value.0); + } + Ok(Self(serde_json::Value::Object(json))) + } else if ob.is_instance_of::() { + let value = bool::extract(ob)?; + Ok(Self(serde_json::Value::Bool(value))) + } else if ob.is_instance_of::() { + let value = i64::extract(ob)?; + Ok(Self(serde_json::Value::Number(value.into()))) + } else if ob.is_instance_of::() { + let value = f64::extract(ob)?; + let value = + serde_json::value::Number::from_f64(value).context("Could not convert f64 to serde_json::Number")?; + Ok(Self(serde_json::Value::Number(value))) + } else if ob.is_instance_of::() { + let value = String::extract(ob)?; + Ok(Self(serde_json::Value::String(value))) + } else if ob.is_instance_of::() { + let value = ob.downcast::()?; + let mut json_values = Vec::new(); + for v in value { + let v = v.extract::()?; + json_values.push(v.0); + } + Ok(Self(serde_json::Value::Array(json_values))) + } else { + if ob.is_none() { + return Ok(Self(serde_json::Value::Null)); + } + Err(anyhow::anyhow!( + "Unsupported type for JSON conversion: {:?}", + ob.get_type() + ))? + } + } +} + pub fn get_model_from(task: &Value) -> Result { Python::with_gil(|py| -> Result { let get_model_from = get_module!(PY_MODULE) @@ -55,6 +109,46 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) - }) } +#[derive(Deserialize)] +pub struct RankResult { + pub corpus_id: i64, + pub score: f64, + pub text: Option, +} + +pub fn rank( + transformer: &str, + query: &str, + documents: Vec<&str>, + kwargs: &serde_json::Value, +) -> Result> { + let kwargs = serde_json::to_string(kwargs)?; + Python::with_gil(|py| -> Result> { + let embed: Py = get_module!(PY_MODULE).getattr(py, "rank").format_traceback(py)?; + let output = embed + .call1( + py, + PyTuple::new( + py, + &[ + transformer.to_string().into_py(py), + query.into_py(py), + documents.into_py(py), + kwargs.into_py(py), + ], + ), + ) + .format_traceback(py)?; + let out: Vec = output.extract(py).format_traceback(py)?; + out.into_iter() + .map(|x| { + let x: RankResult = serde_json::from_value(x.0)?; + Ok(x) + }) + .collect() + }) +} + pub fn finetune_text_classification( task: &Task, dataset: TextClassificationDataset, diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 782dd7908..baa2c2500 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -12,7 +12,7 @@ import orjson from rouge import Rouge from sacrebleu.metrics import BLEU -from sentence_transformers import SentenceTransformer +from sentence_transformers import SentenceTransformer, CrossEncoder from sklearn.metrics import ( mean_squared_error, r2_score, @@ -500,6 +500,33 @@ def transform(task, args, inputs, stream=False): return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() +def create_cross_encoder(transformer): + return CrossEncoder(transformer) + + +def rank_using(model, query, documents, kwargs): + if isinstance(kwargs, str): + kwargs = orjson.loads(kwargs) + + # The score is a numpy float32 before we convert it + return [ + {"score": x.pop("score").item(), **x} + for x in model.rank(query, documents, **kwargs) + ] + + +def rank(transformer, query, documents, kwargs): + kwargs = orjson.loads(kwargs) + + if transformer not in __cache_sentence_transformer_by_name: + __cache_sentence_transformer_by_name[transformer] = create_cross_encoder( + transformer + ) + model = __cache_sentence_transformer_by_name[transformer] + + return rank_using(model, query, documents, kwargs) + + def create_embedding(transformer): return SentenceTransformer(transformer)