🌐 AI搜索 & 代理 主页
Skip to content

Commit 108b052

Browse files
authored
Add support for XGBoost eval_metrics and objective (#1103)
1 parent e22134f commit 108b052

File tree

4 files changed

+124
-39
lines changed

4 files changed

+124
-39
lines changed

pgml-extension/examples/regression.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ SELECT * FROM pgml.deployed_models ORDER BY deployed_at DESC LIMIT 5;
106106
-- do a hyperparam search on your favorite algorithm
107107
SELECT pgml.train(
108108
'Diabetes Progression',
109-
algorithm => 'xgboost',
109+
algorithm => 'xgboost',
110+
hyperparams => '{"eval_metric": "rmse"}'::JSONB,
110111
search => 'grid',
111112
search_params => '{
112113
"max_depth": [1, 2],

pgml-extension/src/bindings/xgboost.rs

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,9 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
128128
},
129129
"max_leaves" => params.max_leaves(value.as_u64().unwrap() as u32),
130130
"max_bin" => params.max_bin(value.as_u64().unwrap() as u32),
131-
"booster" | "n_estimators" | "boost_rounds" => &mut params, // Valid but not relevant to this section
131+
"booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => {
132+
&mut params
133+
} // Valid but not relevant to this section
132134
"nthread" => &mut params,
133135
"random_state" => &mut params,
134136
_ => panic!("Unknown hyperparameter {:?}: {:?}", key, value),
@@ -152,6 +154,52 @@ pub fn fit_classification(
152154
)
153155
}
154156

157+
fn eval_metric_from_string(name: &str) -> learning::EvaluationMetric {
158+
match name {
159+
"rmse" => learning::EvaluationMetric::RMSE,
160+
"mae" => learning::EvaluationMetric::MAE,
161+
"logloss" => learning::EvaluationMetric::LogLoss,
162+
"merror" => learning::EvaluationMetric::MultiClassErrorRate,
163+
"mlogloss" => learning::EvaluationMetric::MultiClassLogLoss,
164+
"auc" => learning::EvaluationMetric::AUC,
165+
"ndcg" => learning::EvaluationMetric::NDCG,
166+
"ndcg-" => learning::EvaluationMetric::NDCGNegative,
167+
"map" => learning::EvaluationMetric::MAP,
168+
"map-" => learning::EvaluationMetric::MAPNegative,
169+
"poisson-nloglik" => learning::EvaluationMetric::PoissonLogLoss,
170+
"gamma-nloglik" => learning::EvaluationMetric::GammaLogLoss,
171+
"cox-nloglik" => learning::EvaluationMetric::CoxLogLoss,
172+
"gamma-deviance" => learning::EvaluationMetric::GammaDeviance,
173+
"tweedie-nloglik" => learning::EvaluationMetric::TweedieLogLoss,
174+
_ => error!("Unknown eval_metric: {:?}", name),
175+
}
176+
}
177+
178+
fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective {
179+
match name {
180+
"reg:linear" => learning::Objective::RegLinear,
181+
"reg:logistic" => learning::Objective::RegLogistic,
182+
"binary:logistic" => learning::Objective::BinaryLogistic,
183+
"binary:logitraw" => learning::Objective::BinaryLogisticRaw,
184+
"gpu:reg:linear" => learning::Objective::GpuRegLinear,
185+
"gpu:reg:logistic" => learning::Objective::GpuRegLogistic,
186+
"gpu:binary:logistic" => learning::Objective::GpuBinaryLogistic,
187+
"gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw,
188+
"count:poisson" => learning::Objective::CountPoisson,
189+
"survival:cox" => learning::Objective::SurvivalCox,
190+
"multi:softmax" => {
191+
learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap())
192+
}
193+
"multi:softprob" => {
194+
learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap())
195+
}
196+
"rank:pairwise" => learning::Objective::RankPairwise,
197+
"reg:gamma" => learning::Objective::RegGamma,
198+
"reg:tweedie" => learning::Objective::RegTweedie(Some(dataset.num_distinct_labels as f32)),
199+
_ => error!("Unknown objective: {:?}", name),
200+
}
201+
}
202+
155203
fn fit(
156204
dataset: &Dataset,
157205
hyperparams: &Hyperparams,
@@ -170,14 +218,40 @@ fn fit(
170218
Some(value) => value.as_u64().unwrap(),
171219
None => 0,
172220
};
173-
let learning_params = learning::LearningTaskParametersBuilder::default()
174-
.objective(objective)
221+
let eval_metrics = match hyperparams.get("eval_metric") {
222+
Some(metrics) => {
223+
if metrics.is_array() {
224+
learning::Metrics::Custom(
225+
metrics
226+
.as_array()
227+
.unwrap()
228+
.iter()
229+
.map(|metric| eval_metric_from_string(metric.as_str().unwrap()))
230+
.collect(),
231+
)
232+
} else {
233+
learning::Metrics::Custom(Vec::from([eval_metric_from_string(
234+
metrics.as_str().unwrap(),
235+
)]))
236+
}
237+
}
238+
None => learning::Metrics::Auto,
239+
};
240+
let learning_params = match learning::LearningTaskParametersBuilder::default()
241+
.objective(match hyperparams.get("objective") {
242+
Some(value) => objective_from_string(value.as_str().unwrap(), dataset),
243+
None => objective,
244+
})
245+
.eval_metrics(eval_metrics)
175246
.seed(seed)
176247
.build()
177-
.unwrap();
248+
{
249+
Ok(params) => params,
250+
Err(e) => error!("Failed to parse learning params:\n\n{}", e),
251+
};
178252

179253
// overall configuration for Booster
180-
let booster_params = BoosterParametersBuilder::default()
254+
let booster_params = match BoosterParametersBuilder::default()
181255
.learning_params(learning_params)
182256
.booster_type(match hyperparams.get("booster") {
183257
Some(value) => match value.as_str().unwrap() {
@@ -195,7 +269,10 @@ fn fit(
195269
)
196270
.verbose(true)
197271
.build()
198-
.unwrap();
272+
{
273+
Ok(params) => params,
274+
Err(e) => error!("Failed to configure booster:\n\n{}", e),
275+
};
199276

200277
let mut builder = TrainingParametersBuilder::default();
201278
// number of training iterations is aliased
@@ -207,18 +284,24 @@ fn fit(
207284
},
208285
};
209286

210-
let params = builder
287+
let params = match builder
211288
// dataset to train with
212289
.dtrain(&dtrain)
213290
// optional datasets to evaluate against in each iteration
214291
.evaluation_sets(Some(evaluation_sets))
215292
// model parameters
216293
.booster_params(booster_params)
217294
.build()
218-
.unwrap();
295+
{
296+
Ok(params) => params,
297+
Err(e) => error!("Failed to create training parameters:\n\n{}", e),
298+
};
219299

220300
// train model, and print evaluation data
221-
let booster = Booster::train(&params).unwrap();
301+
let booster = match Booster::train(&params) {
302+
Ok(booster) => booster,
303+
Err(e) => error!("Failed to train model:\n\n{}", e),
304+
};
222305

223306
Ok(Box::new(Estimator { estimator: booster }))
224307
}

pgml-extension/src/orm/model.rs

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use anyhow::{anyhow, bail, Result};
22
use parking_lot::Mutex;
33
use std::collections::HashMap;
44
use std::fmt::{Display, Error, Formatter};
5+
use std::num::NonZeroUsize;
56
use std::str::FromStr;
67
use std::sync::Arc;
78
use std::time::Instant;
@@ -962,16 +963,13 @@ impl Model {
962963
pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec<f32> {
963964
// TODO handle FLOAT4[] as if it were pgrx::datum::AnyElement, skipping all this, and going straight to predict
964965
let mut features = Vec::new(); // TODO pre-allocate space
965-
let columns = &self.snapshot.columns;
966966
for row in rows {
967967
match row.oid() {
968968
pgrx_pg_sys::RECORDOID => {
969969
let tuple = unsafe { PgHeapTuple::from_composite_datum(row.datum()) };
970-
for index in 1..tuple.len() + 1 {
971-
let column = &columns[index - 1];
972-
let attribute = tuple
973-
.get_attribute_by_index(index.try_into().unwrap())
974-
.unwrap();
970+
for (i, column) in self.snapshot.features().enumerate() {
971+
let index = NonZeroUsize::new(i + 1).unwrap();
972+
let attribute = tuple.get_attribute_by_index(index).unwrap();
975973
match &column.statistics.categories {
976974
Some(_categories) => {
977975
let key = match attribute.atttypid {
@@ -982,14 +980,14 @@ impl Model {
982980
| pgrx_pg_sys::VARCHAROID
983981
| pgrx_pg_sys::BPCHAROID => {
984982
let element: Result<Option<String>, TryFromDatumError> =
985-
tuple.get_by_index(index.try_into().unwrap());
983+
tuple.get_by_index(index);
986984
element
987985
.unwrap()
988986
.unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string())
989987
}
990988
pgrx_pg_sys::BOOLOID => {
991989
let element: Result<Option<bool>, TryFromDatumError> =
992-
tuple.get_by_index(index.try_into().unwrap());
990+
tuple.get_by_index(index);
993991
element
994992
.unwrap()
995993
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
@@ -998,7 +996,7 @@ impl Model {
998996
}
999997
pgrx_pg_sys::INT2OID => {
1000998
let element: Result<Option<i16>, TryFromDatumError> =
1001-
tuple.get_by_index(index.try_into().unwrap());
999+
tuple.get_by_index(index);
10021000
element
10031001
.unwrap()
10041002
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
@@ -1007,7 +1005,7 @@ impl Model {
10071005
}
10081006
pgrx_pg_sys::INT4OID => {
10091007
let element: Result<Option<i32>, TryFromDatumError> =
1010-
tuple.get_by_index(index.try_into().unwrap());
1008+
tuple.get_by_index(index);
10111009
element
10121010
.unwrap()
10131011
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
@@ -1016,7 +1014,7 @@ impl Model {
10161014
}
10171015
pgrx_pg_sys::INT8OID => {
10181016
let element: Result<Option<i64>, TryFromDatumError> =
1019-
tuple.get_by_index(index.try_into().unwrap());
1017+
tuple.get_by_index(index);
10201018
element
10211019
.unwrap()
10221020
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
@@ -1025,7 +1023,7 @@ impl Model {
10251023
}
10261024
pgrx_pg_sys::FLOAT4OID => {
10271025
let element: Result<Option<f32>, TryFromDatumError> =
1028-
tuple.get_by_index(index.try_into().unwrap());
1026+
tuple.get_by_index(index);
10291027
element
10301028
.unwrap()
10311029
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
@@ -1034,7 +1032,7 @@ impl Model {
10341032
}
10351033
pgrx_pg_sys::FLOAT8OID => {
10361034
let element: Result<Option<f64>, TryFromDatumError> =
1037-
tuple.get_by_index(index.try_into().unwrap());
1035+
tuple.get_by_index(index);
10381036
element
10391037
.unwrap()
10401038
.map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
@@ -1056,79 +1054,79 @@ impl Model {
10561054
}
10571055
pgrx_pg_sys::BOOLOID => {
10581056
let element: Result<Option<bool>, TryFromDatumError> =
1059-
tuple.get_by_index(index.try_into().unwrap());
1057+
tuple.get_by_index(index);
10601058
features.push(
10611059
element.unwrap().map_or(f32::NAN, |v| v as u8 as f32),
10621060
);
10631061
}
10641062
pgrx_pg_sys::INT2OID => {
10651063
let element: Result<Option<i16>, TryFromDatumError> =
1066-
tuple.get_by_index(index.try_into().unwrap());
1064+
tuple.get_by_index(index);
10671065
features
10681066
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
10691067
}
10701068
pgrx_pg_sys::INT4OID => {
10711069
let element: Result<Option<i32>, TryFromDatumError> =
1072-
tuple.get_by_index(index.try_into().unwrap());
1070+
tuple.get_by_index(index);
10731071
features
10741072
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
10751073
}
10761074
pgrx_pg_sys::INT8OID => {
10771075
let element: Result<Option<i64>, TryFromDatumError> =
1078-
tuple.get_by_index(index.try_into().unwrap());
1076+
tuple.get_by_index(index);
10791077
features
10801078
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
10811079
}
10821080
pgrx_pg_sys::FLOAT4OID => {
10831081
let element: Result<Option<f32>, TryFromDatumError> =
1084-
tuple.get_by_index(index.try_into().unwrap());
1082+
tuple.get_by_index(index);
10851083
features.push(element.unwrap().map_or(f32::NAN, |v| v));
10861084
}
10871085
pgrx_pg_sys::FLOAT8OID => {
10881086
let element: Result<Option<f64>, TryFromDatumError> =
1089-
tuple.get_by_index(index.try_into().unwrap());
1087+
tuple.get_by_index(index);
10901088
features
10911089
.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
10921090
}
10931091
// TODO handle NULL to NaN for arrays
10941092
pgrx_pg_sys::BOOLARRAYOID => {
10951093
let element: Result<Option<Vec<bool>>, TryFromDatumError> =
1096-
tuple.get_by_index(index.try_into().unwrap());
1094+
tuple.get_by_index(index);
10971095
for j in element.as_ref().unwrap().as_ref().unwrap() {
10981096
features.push(*j as i8 as f32);
10991097
}
11001098
}
11011099
pgrx_pg_sys::INT2ARRAYOID => {
11021100
let element: Result<Option<Vec<i16>>, TryFromDatumError> =
1103-
tuple.get_by_index(index.try_into().unwrap());
1101+
tuple.get_by_index(index);
11041102
for j in element.as_ref().unwrap().as_ref().unwrap() {
11051103
features.push(*j as f32);
11061104
}
11071105
}
11081106
pgrx_pg_sys::INT4ARRAYOID => {
11091107
let element: Result<Option<Vec<i32>>, TryFromDatumError> =
1110-
tuple.get_by_index(index.try_into().unwrap());
1108+
tuple.get_by_index(index);
11111109
for j in element.as_ref().unwrap().as_ref().unwrap() {
11121110
features.push(*j as f32);
11131111
}
11141112
}
11151113
pgrx_pg_sys::INT8ARRAYOID => {
11161114
let element: Result<Option<Vec<i64>>, TryFromDatumError> =
1117-
tuple.get_by_index(index.try_into().unwrap());
1115+
tuple.get_by_index(index);
11181116
for j in element.as_ref().unwrap().as_ref().unwrap() {
11191117
features.push(*j as f32);
11201118
}
11211119
}
11221120
pgrx_pg_sys::FLOAT4ARRAYOID => {
11231121
let element: Result<Option<Vec<f32>>, TryFromDatumError> =
1124-
tuple.get_by_index(index.try_into().unwrap());
1122+
tuple.get_by_index(index);
11251123
for j in element.as_ref().unwrap().as_ref().unwrap() {
11261124
features.push(*j);
11271125
}
11281126
}
11291127
pgrx_pg_sys::FLOAT8ARRAYOID => {
11301128
let element: Result<Option<Vec<f64>>, TryFromDatumError> =
1131-
tuple.get_by_index(index.try_into().unwrap());
1129+
tuple.get_by_index(index);
11321130
for j in element.as_ref().unwrap().as_ref().unwrap() {
11331131
features.push(*j as f32);
11341132
}

0 commit comments

Comments
 (0)