@@ -13,16 +13,25 @@ use serde_json::json;
1313use crate :: bindings:: sklearn:: package_version;
1414use crate :: orm:: * ;
1515
16+ macro_rules! unwrap_or_error {
17+ ( $i: expr) => {
18+ match $i {
19+ Ok ( v) => v,
20+ Err ( e) => error!( "{e}" ) ,
21+ }
22+ } ;
23+ }
24+
1625#[ cfg( feature = "python" ) ]
1726#[ pg_extern]
1827pub fn activate_venv ( venv : & str ) -> bool {
19- crate :: bindings:: venv:: activate_venv ( venv)
28+ unwrap_or_error ! ( crate :: bindings:: venv:: activate_venv( venv) )
2029}
2130
2231#[ cfg( feature = "python" ) ]
2332#[ pg_extern( immutable, parallel_safe) ]
2433pub fn validate_python_dependencies ( ) -> bool {
25- crate :: bindings:: venv:: activate ( ) ;
34+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
2635
2736 Python :: with_gil ( |py| {
2837 let sys = PyModule :: import ( py, "sys" ) . unwrap ( ) ;
@@ -40,13 +49,12 @@ pub fn validate_python_dependencies() -> bool {
4049 }
4150 } ) ;
4251
43- info ! (
44- "Scikit-learn {}, XGBoost {}, LightGBM {}, NumPy {}" ,
45- package_version( "sklearn" ) ,
46- package_version( "xgboost" ) ,
47- package_version( "lightgbm" ) ,
48- package_version( "numpy" ) ,
49- ) ;
52+ let sklearn = unwrap_or_error ! ( package_version( "sklearn" ) ) ;
53+ let xgboost = unwrap_or_error ! ( package_version( "xgboost" ) ) ;
54+ let lightgbm = unwrap_or_error ! ( package_version( "lightgbm" ) ) ;
55+ let numpy = unwrap_or_error ! ( package_version( "numpy" ) ) ;
56+
57+ info ! ( "Scikit-learn {sklearn}, XGBoost {xgboost}, LightGBM {lightgbm}, NumPy {numpy}" , ) ;
5058
5159 true
5260}
@@ -58,8 +66,8 @@ pub fn validate_python_dependencies() {}
5866#[ cfg( feature = "python" ) ]
5967#[ pg_extern]
6068pub fn python_package_version ( name : & str ) -> String {
61- crate :: bindings:: venv:: activate ( ) ;
62- package_version ( name)
69+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
70+ unwrap_or_error ! ( package_version( name) )
6371}
6472
6573#[ cfg( not( feature = "python" ) ) ]
@@ -71,9 +79,9 @@ pub fn python_package_version(name: &str) {
7179#[ cfg( feature = "python" ) ]
7280#[ pg_extern]
7381pub fn python_pip_freeze ( ) -> TableIterator < ' static , ( name ! ( package, String ) , ) > {
74- crate :: bindings:: venv:: activate ( ) ;
82+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
7583
76- let packages = crate :: bindings:: venv:: freeze ( )
84+ let packages = unwrap_or_error ! ( crate :: bindings:: venv:: freeze( ) )
7785 . into_iter ( )
7886 . map ( |package| ( package, ) ) ;
7987
@@ -99,7 +107,7 @@ pub fn validate_shared_library() {
99107#[ cfg( feature = "python" ) ]
100108#[ pg_extern]
101109fn python_version ( ) -> String {
102- crate :: bindings:: venv:: activate ( ) ;
110+ unwrap_or_error ! ( crate :: bindings:: venv:: activate( ) ) ;
103111 let mut version = String :: new ( ) ;
104112
105113 Python :: with_gil ( |py| {
@@ -479,27 +487,31 @@ fn predict_row(project_name: &str, row: pgrx::datum::AnyElement) -> f32 {
479487
480488#[ pg_extern( immutable, parallel_safe, strict, name = "predict" ) ]
481489fn predict_model ( model_id : i64 , features : Vec < f32 > ) -> f32 {
482- Model :: find_cached ( model_id) . predict ( & features)
490+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
491+ unwrap_or_error ! ( model. predict( & features) )
483492}
484493
485494#[ pg_extern( immutable, parallel_safe, strict, name = "predict_proba" ) ]
486495fn predict_model_proba ( model_id : i64 , features : Vec < f32 > ) -> Vec < f32 > {
487- Model :: find_cached ( model_id) . predict_proba ( & features)
496+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
497+ unwrap_or_error ! ( model. predict_proba( & features) )
488498}
489499
490500#[ pg_extern( immutable, parallel_safe, strict, name = "predict_joint" ) ]
491501fn predict_model_joint ( model_id : i64 , features : Vec < f32 > ) -> Vec < f32 > {
492- Model :: find_cached ( model_id) . predict_joint ( & features)
502+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
503+ unwrap_or_error ! ( model. predict_joint( & features) )
493504}
494505
495506#[ pg_extern( immutable, parallel_safe, strict, name = "predict_batch" ) ]
496507fn predict_model_batch ( model_id : i64 , features : Vec < f32 > ) -> Vec < f32 > {
497- Model :: find_cached ( model_id) . predict_batch ( & features)
508+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
509+ unwrap_or_error ! ( model. predict_batch( & features) )
498510}
499511
500512#[ pg_extern( immutable, parallel_safe, strict, name = "predict" ) ]
501513fn predict_model_row ( model_id : i64 , row : pgrx:: datum:: AnyElement ) -> f32 {
502- let model = Model :: find_cached ( model_id) ;
514+ let model = unwrap_or_error ! ( Model :: find_cached( model_id) ) ;
503515 let snapshot = & model. snapshot ;
504516 let numeric_encoded_features = model. numeric_encode_features ( & [ row] ) ;
505517 let features_width = snapshot. features_width ( ) ;
@@ -514,7 +526,7 @@ fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 {
514526 let column = & snapshot. columns [ position. column_position - 1 ] ;
515527 column. preprocess ( & data, & mut processed, features_width, position. row_position ) ;
516528 } ) ;
517- model. predict ( & processed)
529+ unwrap_or_error ! ( model. predict( & processed) )
518530}
519531
520532#[ pg_extern]
@@ -617,7 +629,11 @@ pub fn chunk(
617629 text : & str ,
618630 kwargs : default ! ( JsonB , "'{}'" ) ,
619631) -> TableIterator < ' static , ( name ! ( chunk_index, i64 ) , name ! ( chunk, String ) ) > {
620- let chunks = crate :: bindings:: langchain:: chunk ( splitter, text, & kwargs. 0 ) ;
632+ let chunks = match crate :: bindings:: langchain:: chunk ( splitter, text, & kwargs. 0 ) {
633+ Ok ( chunks) => chunks,
634+ Err ( e) => error ! ( "{e}" ) ,
635+ } ;
636+
621637 let chunks = chunks
622638 . into_iter ( )
623639 . enumerate ( )
@@ -838,28 +854,23 @@ fn tune(
838854#[ cfg( feature = "python" ) ]
839855#[ pg_extern( name = "sklearn_f1_score" ) ]
840856pub fn sklearn_f1_score ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> f32 {
841- crate :: bindings:: sklearn:: f1 ( & ground_truth, & y_hat)
857+ unwrap_or_error ! ( crate :: bindings:: sklearn:: f1( & ground_truth, & y_hat) )
842858}
843859
844860#[ cfg( feature = "python" ) ]
845861#[ pg_extern( name = "sklearn_r2_score" ) ]
846862pub fn sklearn_r2_score ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> f32 {
847- crate :: bindings:: sklearn:: r2 ( & ground_truth, & y_hat)
863+ unwrap_or_error ! ( crate :: bindings:: sklearn:: r2( & ground_truth, & y_hat) )
848864}
849865
850866#[ cfg( feature = "python" ) ]
851867#[ pg_extern( name = "sklearn_regression_metrics" ) ]
852868pub fn sklearn_regression_metrics ( ground_truth : Vec < f32 > , y_hat : Vec < f32 > ) -> JsonB {
853- JsonB (
854- serde_json:: from_str (
855- & serde_json:: to_string ( & crate :: bindings:: sklearn:: regression_metrics (
856- & ground_truth,
857- & y_hat,
858- ) )
859- . unwrap ( ) ,
860- )
861- . unwrap ( ) ,
862- )
869+ let metrics = unwrap_or_error ! ( crate :: bindings:: sklearn:: regression_metrics(
870+ & ground_truth,
871+ & y_hat,
872+ ) ) ;
873+ JsonB ( json ! ( metrics) )
863874}
864875
865876#[ cfg( feature = "python" ) ]
@@ -869,17 +880,13 @@ pub fn sklearn_classification_metrics(
869880 y_hat : Vec < f32 > ,
870881 num_classes : i64 ,
871882) -> JsonB {
872- JsonB (
873- serde_json:: from_str (
874- & serde_json:: to_string ( & crate :: bindings:: sklearn:: classification_metrics (
875- & ground_truth,
876- & y_hat,
877- num_classes as usize ,
878- ) )
879- . unwrap ( ) ,
880- )
881- . unwrap ( ) ,
882- )
883+ let metrics = unwrap_or_error ! ( crate :: bindings:: sklearn:: classification_metrics(
884+ & ground_truth,
885+ & y_hat,
886+ num_classes as _
887+ ) ) ;
888+
889+ JsonB ( json ! ( metrics) )
883890}
884891
885892#[ pg_extern]
0 commit comments