@@ -6,7 +6,7 @@ use std::{collections::HashMap, path::Path};
66use anyhow:: { anyhow, bail, Context , Result } ;
77use pgrx:: * ;
88use pyo3:: prelude:: * ;
9- use pyo3:: types:: PyTuple ;
9+ use pyo3:: types:: { PyBool , PyDict , PyFloat , PyInt , PyList , PyString , PyTuple } ;
1010use serde_json:: Value ;
1111
1212use crate :: create_pymodule;
@@ -21,6 +21,57 @@ pub use transform::*;
2121
2222create_pymodule ! ( "/src/bindings/transformers/transformers.py" ) ;
2323
24+ // Need a wrapper so we can implement traits for it
25+ struct Json ( Value ) ;
26+
27+ impl From < Json > for Value {
28+ fn from ( value : Json ) -> Self {
29+ value. 0
30+ }
31+ }
32+
33+ impl FromPyObject < ' _ > for Json {
34+ fn extract ( ob : & PyAny ) -> PyResult < Self > {
35+ if ob. is_instance_of :: < PyDict > ( ) {
36+ let dict: & PyDict = ob. downcast ( ) ?;
37+ let mut json = serde_json:: Map :: new ( ) ;
38+ for ( key, value) in dict. iter ( ) {
39+ let value = Json :: extract ( value) ?;
40+ json. insert ( String :: extract ( key) ?, value. 0 ) ;
41+ }
42+ Ok ( Self ( serde_json:: Value :: Object ( json) ) )
43+ } else if ob. is_instance_of :: < PyBool > ( ) {
44+ let value = bool:: extract ( ob) ?;
45+ Ok ( Self ( serde_json:: Value :: Bool ( value) ) )
46+ } else if ob. is_instance_of :: < PyInt > ( ) {
47+ let value = i64:: extract ( ob) ?;
48+ Ok ( Self ( serde_json:: Value :: Number ( value. into ( ) ) ) )
49+ } else if ob. is_instance_of :: < PyFloat > ( ) {
50+ let value = f64:: extract ( ob) ?;
51+ let value =
52+ serde_json:: value:: Number :: from_f64 ( value) . context ( "Could not convert f64 to serde_json::Number" ) ?;
53+ Ok ( Self ( serde_json:: Value :: Number ( value) ) )
54+ } else if ob. is_instance_of :: < PyString > ( ) {
55+ let value = String :: extract ( ob) ?;
56+ Ok ( Self ( serde_json:: Value :: String ( value) ) )
57+ } else if ob. is_instance_of :: < PyList > ( ) {
58+ let value = ob. downcast :: < PyList > ( ) ?;
59+ let mut json_values = Vec :: new ( ) ;
60+ for v in value {
61+ let v = v. extract :: < Json > ( ) ?;
62+ json_values. push ( v. 0 ) ;
63+ }
64+ Ok ( Self ( serde_json:: Value :: Array ( json_values) ) )
65+ } else {
66+ if ob. is_none ( ) {
67+ return Ok ( Self ( serde_json:: Value :: Null ) ) ;
68+ }
69+ eprintln ! ( "\n \n THE OBJ: {:?}\n \n " , ob. get_type( ) ) ;
70+ Err ( anyhow:: anyhow!( "Unsupported type for JSON conversion" ) ) ?
71+ }
72+ }
73+ }
74+
2475pub fn get_model_from ( task : & Value ) -> Result < String > {
2576 Python :: with_gil ( |py| -> Result < String > {
2677 let get_model_from = get_module ! ( PY_MODULE )
@@ -55,6 +106,29 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
55106 } )
56107}
57108
109+ pub fn rank ( transformer : & str , query : & str , documents : Vec < & str > , kwargs : & serde_json:: Value ) -> Result < Vec < Value > > {
110+ let kwargs = serde_json:: to_string ( kwargs) ?;
111+ Python :: with_gil ( |py| -> Result < Vec < Value > > {
112+ let embed: Py < PyAny > = get_module ! ( PY_MODULE ) . getattr ( py, "rank" ) . format_traceback ( py) ?;
113+ let output = embed
114+ . call1 (
115+ py,
116+ PyTuple :: new (
117+ py,
118+ & [
119+ transformer. to_string ( ) . into_py ( py) ,
120+ query. into_py ( py) ,
121+ documents. into_py ( py) ,
122+ kwargs. into_py ( py) ,
123+ ] ,
124+ ) ,
125+ )
126+ . format_traceback ( py) ?;
127+ let out: Vec < Json > = output. extract ( py) . format_traceback ( py) ?;
128+ Ok ( out. into_iter ( ) . map ( |x| x. into ( ) ) . collect ( ) )
129+ } )
130+ }
131+
58132pub fn finetune_text_classification (
59133 task : & Task ,
60134 dataset : TextClassificationDataset ,
0 commit comments