🌐 AI搜索 & 代理 主页
Skip to content

Commit 8800a21

Browse files
committed
Clean up batching
1 parent da7df56 commit 8800a21

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

pgml-extension/src/api.rs

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -610,15 +610,14 @@ pub fn embed_batch(
610610

611611
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
612612
#[pg_extern(immutable, parallel_safe, name = "rank")]
613-
// pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec<pgrx::JsonB> {
614613
pub fn rank(
615614
transformer: &str,
616615
query: &str,
617616
documents: Vec<&str>,
618617
kwargs: default!(JsonB, "'{}'"),
619-
) -> SetOfIterator<'static, pgrx::JsonB> {
618+
) -> TableIterator<'static, (name!(corpus_id, i64), name!(score, f64), name!(text, Option<String>))> {
620619
match crate::bindings::transformers::rank(transformer, query, documents, &kwargs.0) {
621-
Ok(output) => SetOfIterator::new(output.into_iter().map(pgrx::JsonB)),
620+
Ok(output) => TableIterator::new(output.into_iter().map(|x| (x.corpus_id, x.score, x.text))),
622621
Err(e) => error!("{e}"),
623622
}
624623
}

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 25 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ use anyhow::{anyhow, bail, Context, Result};
77
use pgrx::*;
88
use pyo3::prelude::*;
99
use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple};
10+
use serde::Deserialize;
1011
use serde_json::Value;
1112

1213
use crate::create_pymodule;
@@ -66,8 +67,10 @@ impl FromPyObject<'_> for Json {
6667
if ob.is_none() {
6768
return Ok(Self(serde_json::Value::Null));
6869
}
69-
eprintln!("\n\nTHE OBJ: {:?}\n\n", ob.get_type());
70-
Err(anyhow::anyhow!("Unsupported type for JSON conversion"))?
70+
Err(anyhow::anyhow!(
71+
"Unsupported type for JSON conversion: {:?}",
72+
ob.get_type()
73+
))?
7174
}
7275
}
7376
}
@@ -106,9 +109,21 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
106109
})
107110
}
108111

109-
pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: &serde_json::Value) -> Result<Vec<Value>> {
112+
#[derive(Deserialize)]
113+
pub struct RankResult {
114+
pub corpus_id: i64,
115+
pub score: f64,
116+
pub text: Option<String>,
117+
}
118+
119+
pub fn rank(
120+
transformer: &str,
121+
query: &str,
122+
documents: Vec<&str>,
123+
kwargs: &serde_json::Value,
124+
) -> Result<Vec<RankResult>> {
110125
let kwargs = serde_json::to_string(kwargs)?;
111-
Python::with_gil(|py| -> Result<Vec<Value>> {
126+
Python::with_gil(|py| -> Result<Vec<RankResult>> {
112127
let embed: Py<PyAny> = get_module!(PY_MODULE).getattr(py, "rank").format_traceback(py)?;
113128
let output = embed
114129
.call1(
@@ -125,7 +140,12 @@ pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: &serde
125140
)
126141
.format_traceback(py)?;
127142
let out: Vec<Json> = output.extract(py).format_traceback(py)?;
128-
Ok(out.into_iter().map(|x| x.into()).collect())
143+
out.into_iter()
144+
.map(|x| {
145+
let x: RankResult = serde_json::from_value(x.0)?;
146+
Ok(x)
147+
})
148+
.collect()
129149
})
130150
}
131151

0 commit comments

Comments
 (0)