diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index bb97b31e8..224f21caa 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -558,6 +558,26 @@ pub fn embed_batch( } } +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "embed2")] +pub fn embed_batch2<'a>( + transformer: &str, + inputs: Vec<&'a str>, + kwargs: default!(JsonB, "'{}'"), +) -> TableIterator<'a, (name!(text, String), name!(embedding, Vec))> { + let rows = match crate::bindings::transformers::embed(transformer, inputs.clone(), &kwargs.0) { + Ok(rows) => rows, + Err(e) => { + error!("{e}"); + } + }; + TableIterator::new( + inputs.into_iter().zip(rows.into_iter()).map(|(text, embedding)| { + (text.to_string(), embedding) + }), + ) +} + /// Clears the GPU cache. /// /// # Arguments