@@ -7,6 +7,7 @@ use anyhow::{anyhow, bail, Context, Result};
77use pgrx:: * ;
88use pyo3:: prelude:: * ;
99use pyo3:: types:: { PyBool , PyDict , PyFloat , PyInt , PyList , PyString , PyTuple } ;
10+ use serde:: Deserialize ;
1011use serde_json:: Value ;
1112
1213use crate :: create_pymodule;
@@ -66,8 +67,10 @@ impl FromPyObject<'_> for Json {
6667 if ob. is_none ( ) {
6768 return Ok ( Self ( serde_json:: Value :: Null ) ) ;
6869 }
69- eprintln ! ( "\n \n THE OBJ: {:?}\n \n " , ob. get_type( ) ) ;
70- Err ( anyhow:: anyhow!( "Unsupported type for JSON conversion" ) ) ?
70+ Err ( anyhow:: anyhow!(
71+ "Unsupported type for JSON conversion: {:?}" ,
72+ ob. get_type( )
73+ ) ) ?
7174 }
7275 }
7376}
@@ -106,9 +109,21 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
106109 } )
107110}
108111
109- pub fn rank ( transformer : & str , query : & str , documents : Vec < & str > , kwargs : & serde_json:: Value ) -> Result < Vec < Value > > {
112+ #[ derive( Deserialize ) ]
113+ pub struct RankResult {
114+ pub corpus_id : i64 ,
115+ pub score : f64 ,
116+ pub text : Option < String > ,
117+ }
118+
119+ pub fn rank (
120+ transformer : & str ,
121+ query : & str ,
122+ documents : Vec < & str > ,
123+ kwargs : & serde_json:: Value ,
124+ ) -> Result < Vec < RankResult > > {
110125 let kwargs = serde_json:: to_string ( kwargs) ?;
111- Python :: with_gil ( |py| -> Result < Vec < Value > > {
126+ Python :: with_gil ( |py| -> Result < Vec < RankResult > > {
112127 let embed: Py < PyAny > = get_module ! ( PY_MODULE ) . getattr ( py, "rank" ) . format_traceback ( py) ?;
113128 let output = embed
114129 . call1 (
@@ -125,7 +140,12 @@ pub fn rank(transformer: &str, query: &str, documents: Vec<&str>, kwargs: &serde
125140 )
126141 . format_traceback ( py) ?;
127142 let out: Vec < Json > = output. extract ( py) . format_traceback ( py) ?;
128- Ok ( out. into_iter ( ) . map ( |x| x. into ( ) ) . collect ( ) )
143+ out. into_iter ( )
144+ . map ( |x| {
145+ let x: RankResult = serde_json:: from_value ( x. 0 ) ?;
146+ Ok ( x)
147+ } )
148+ . collect ( )
129149 } )
130150}
131151
0 commit comments