From 4612b4f655a7eaaddc9c0eb0de2088a62ee22600 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Tue, 5 Mar 2024 15:45:58 +0800 Subject: [PATCH 1/9] Refactor the initialization of GUC parameters. Managing GUC parameters in different places is hard to maintain. This patch organizes GUC definitions in a single place. Also, we use define_xxx_guc() APIs to define these parameters and it will allow us to manage GucContext, GucFlags in future. P.S., the test case test_trusted_model doesn't seem correct. I fixed it in this patch. --- pgml-extension/Cargo.lock | 1 + pgml-extension/Cargo.toml | 1 + pgml-extension/src/bindings/python/mod.rs | 8 +-- .../src/bindings/transformers/whitelist.rs | 55 +++++++++------- pgml-extension/src/config.rs | 66 ++++++++++++++----- pgml-extension/src/lib.rs | 1 + 6 files changed, 88 insertions(+), 44 deletions(-) diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index fbbb90e9d..6a88a8e2a 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1733,6 +1733,7 @@ dependencies = [ "heapless", "indexmap 2.1.0", "itertools 0.12.0", + "lazy_static", "lightgbm", "linfa", "linfa-linear", diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index 362bb017b..f0b710a06 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -49,6 +49,7 @@ serde = { version = "1.0" } serde_json = { version = "1.0", features = ["preserve_order"] } typetag = "0.2" xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" } +lazy_static = "1.4.0" [dev-dependencies] pgrx-tests = "=0.11.2" diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index ba59bef8e..2e48052f3 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -6,11 +6,9 @@ use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::config::get_config; +use crate::config::PGML_VENV; use crate::create_pymodule; -static 
CONFIG_NAME: &str = "pgml.venv"; - create_pymodule!("/src/bindings/python/python.py"); pub fn activate_venv(venv: &str) -> Result { @@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result { } pub fn activate() -> Result { - match get_config(CONFIG_NAME) { - Some(venv) => activate_venv(&venv), + match PGML_VENV.1.get() { + Some(venv) => activate_venv(&venv.to_string_lossy()), None => Ok(false), } } diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index 0194180c0..ac7f3bf4d 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -1,13 +1,11 @@ use anyhow::{bail, Error}; +use pgrx::GucSetting; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; use serde_json::Value; +use std::ffi::CStr; -use crate::config::get_config; - -static CONFIG_HF_WHITELIST: &str = "pgml.huggingface_whitelist"; -static CONFIG_HF_TRUST_REMOTE_CODE_BOOL: &str = "pgml.huggingface_trust_remote_code"; -static CONFIG_HF_TRUST_WHITELIST: &str = "pgml.huggingface_trust_remote_code_whitelist"; +use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_WHITELIST, PGML_HF_WHITELIST}; /// Verify that the model in the task JSON is allowed based on the huggingface whitelists. pub fn verify_task(task: &Value) -> Result<(), Error> { @@ -15,33 +13,42 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { Some(model) => model.to_string(), None => return Ok(()), }; - let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST); + let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST.1); let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { - bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf"); + bail!( + "model {} is not whitelisted. 
Consider adding to {} in postgresql.conf", + task_model, + PGML_HF_WHITELIST.0 + ); } let task_trust = get_trust_remote_code(task); - let trust_remote_code = get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL) - .map(|v| v == "true") - .unwrap_or(true); + let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.1.get(); - let trusted_models = config_csv_list(CONFIG_HF_TRUST_WHITELIST); + let trusted_models = config_csv_list(&PGML_HF_TRUST_WHITELIST.1); let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model); let remote_code_allowed = trust_remote_code && model_is_trusted; if !remote_code_allowed && task_trust == Some(true) { - bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}"); + bail!( + "model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}", + task_model, + PGML_HF_TRUST_REMOTE_CODE.0, + task_model, + PGML_HF_TRUST_WHITELIST.0 + ); } Ok(()) } -fn config_csv_list(name: &str) -> Vec { - match get_config(name) { +fn config_csv_list(csv_list: &GucSetting>) -> Vec { + match csv_list.get() { Some(value) => value + .to_string_lossy() .trim_matches('"') .split(',') .filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) }) @@ -122,7 +129,7 @@ mod tests { #[pg_test] fn test_empty_whitelist() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, "").unwrap(); + set_config(PGML_HF_WHITELIST.0, "").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); @@ -131,12 +138,12 @@ mod tests { #[pg_test] fn test_nonempty_whitelist() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, model).unwrap(); + set_config(PGML_HF_WHITELIST.0, model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = 
serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); - set_config(CONFIG_HF_WHITELIST, "other_model").unwrap(); + set_config(PGML_HF_WHITELIST.0, "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); @@ -145,8 +152,8 @@ mod tests { #[pg_test] fn test_trusted_model() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, model).unwrap(); - set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap(); + set_config(PGML_HF_WHITELIST.0, model).unwrap(); + set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); @@ -154,9 +161,9 @@ mod tests { let task_json = format!(json_template!(), model, true); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task(&task).is_ok()); + assert!(verify_task(&task).is_err()); - set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap(); + set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); @@ -169,8 +176,8 @@ mod tests { #[pg_test] fn test_untrusted_model() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, model).unwrap(); - set_config(CONFIG_HF_TRUST_WHITELIST, "other_model").unwrap(); + set_config(PGML_HF_WHITELIST.0, model).unwrap(); + set_config(PGML_HF_TRUST_WHITELIST.0, "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); @@ -180,7 +187,7 @@ mod tests { let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); - set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap(); + 
set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index 8f9ade29a..b1f5f43ac 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -1,16 +1,58 @@ +use lazy_static::lazy_static; +use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; use std::ffi::CStr; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; -use pgrx_pg_sys::AsPgCStr; -pub fn get_config(name: &str) -> Option { - // SAFETY: name is not null because it is a Rust reference. - let ptr = unsafe { pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(), true, false) }; - (!ptr.is_null()).then(move || { - // SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer. - unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string() - }) +lazy_static! 
{ + pub static ref PGML_VENV: (&'static str, GucSetting>) = + ("pgml.venv", GucSetting::>::new(None)); + pub static ref PGML_HF_WHITELIST: (&'static str, GucSetting>) = ( + "pgml.huggingface_whitelist", + GucSetting::>::new(None), + ); + pub static ref PGML_HF_TRUST_REMOTE_CODE: (&'static str, GucSetting) = + ("pgml.huggingface_trust_remote_code", GucSetting::::new(false)); + pub static ref PGML_HF_TRUST_WHITELIST: (&'static str, GucSetting>) = ( + "pgml.huggingface_trust_remote_code_whitelist", + GucSetting::>::new(None), + ); +} + +pub fn initialize_server_params() { + GucRegistry::define_string_guc( + PGML_VENV.0, + "Python's virtual environment path", + "", + &PGML_VENV.1, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_string_guc( + PGML_HF_WHITELIST.0, + "Models allowed to be downloaded from huggingface", + "", + &PGML_HF_WHITELIST.1, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_bool_guc( + PGML_HF_TRUST_REMOTE_CODE.0, + "Whether model can execute remote codes", + "", + &PGML_HF_TRUST_REMOTE_CODE.1, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_string_guc( + PGML_HF_TRUST_WHITELIST.0, + "Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'", + "", + &PGML_HF_TRUST_WHITELIST.1, + GucContext::Userset, + GucFlags::default(), + ); } #[cfg(any(test, feature = "pg_test"))] @@ -26,17 +68,11 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> { mod tests { use super::*; - #[pg_test] - fn read_config_max_connections() { - let name = "max_connections"; - assert_eq!(get_config(name), Some("100".into())); - } - #[pg_test] fn read_pgml_huggingface_whitelist() { let name = "pgml.huggingface_whitelist"; let value = "meta-llama/Llama-2-7b"; set_config(name, value).unwrap(); - assert_eq!(get_config(name), Some(value.into())); + assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value); } } diff --git 
a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index 6c2884cee..4cc27322e 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema"); #[cfg(not(feature = "use_as_lib"))] #[pg_guard] pub extern "C" fn _PG_init() { + config::initialize_server_params(); bindings::python::activate().expect("Error setting python venv"); orm::project::init(); } From 58e2aeca8c849851d8b233d357afd650aef41926 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Fri, 8 Mar 2024 08:20:40 +0800 Subject: [PATCH 2/9] Allow user to limit the number of threads that OpenMP spawns. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch introduces a new GUC parameter pgml.omp_num_threads to control the number of threads that OpenMP can spawn. Co-authored-by: Xuebin Su Co-authored-by: Xuebin Su (苏学斌) <12034000+xuebinsu@users.noreply.github.com> --- pgml-extension/src/config.rs | 18 ++++++++++++++++++ pgml-extension/src/lib.rs | 12 +++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index b1f5f43ac..691830cbe 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -18,6 +18,8 @@ lazy_static! { "pgml.huggingface_trust_remote_code_whitelist", GucSetting::>::new(None), ); + pub static ref PGML_OMP_NUM_THREADS: (&'static str, GucSetting) = + ("pgml.omp_num_threads", GucSetting::::new(-1)); } pub fn initialize_server_params() { @@ -53,6 +55,16 @@ pub fn initialize_server_params() { GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_int_guc( + PGML_OMP_NUM_THREADS.0, + "Specifies the number of threads used by default of underlying OpenMP library. 
Only positive integers are valid", + "", + &PGML_OMP_NUM_THREADS.1, + -1, + i32::max_value(), + GucContext::Backend, + GucFlags::default(), + ); } #[cfg(any(test, feature = "pg_test"))] @@ -75,4 +87,10 @@ mod tests { set_config(name, value).unwrap(); assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value); } + + #[pg_test] + fn omp_num_threads_cannot_be_set_after_startup() { + let result = std::panic::catch_unwind(|| set_config("pgml.omp_num_threads", "1")); + assert!(result.is_err()); + } } diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index 4cc27322e..1a45f03d3 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -21,10 +21,20 @@ pg_module_magic!(); extension_sql_file!("../sql/schema.sql", name = "schema"); +extern "C" { + fn omp_set_num_threads(num_threads: i32); +} + #[cfg(not(feature = "use_as_lib"))] #[pg_guard] pub extern "C" fn _PG_init() { config::initialize_server_params(); + let omp_num_threads = config::PGML_OMP_NUM_THREADS.1.get(); + if omp_num_threads > 0 { + unsafe { + omp_set_num_threads(omp_num_threads); + } + } bindings::python::activate().expect("Error setting python venv"); orm::project::init(); } @@ -54,7 +64,7 @@ pub mod pg_test { pub fn postgresql_conf_options() -> Vec<&'static str> { // return any postgresql.conf settings that are required for your tests - let mut options = vec!["shared_preload_libraries = 'pgml'"]; + let mut options = vec!["shared_preload_libraries = 'pgml'", "pgml.omp_num_threads = '1'"]; if let Some(venv) = option_env!("PGML_VENV") { let option = format!("pgml.venv = '{venv}'"); options.push(Box::leak(option.into_boxed_str())); From 4e84e083c7ed45d6496aa70489462bd01a227c06 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Fri, 5 Apr 2024 21:10:42 +0800 Subject: [PATCH 3/9] Address review comments. 
--- pgml-extension/Cargo.lock | 1 - pgml-extension/Cargo.toml | 1 - pgml-extension/src/config.rs | 28 +++++++++++++++------------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index 1bec18628..c9db39e9b 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -1753,7 +1753,6 @@ dependencies = [ "heapless", "indexmap 2.1.0", "itertools 0.12.0", - "lazy_static", "lightgbm", "linfa", "linfa-linear", diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index f6514c134..7aea7ba7c 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -49,7 +49,6 @@ serde = { version = "1.0" } serde_json = { version = "1.0", features = ["preserve_order"] } typetag = "0.2" xgboost = { git = "https://github.com/postgresml/rust-xgboost", branch = "master" } -lazy_static = "1.4.0" [dev-dependencies] pgrx-tests = "=0.11.3" diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index 691830cbe..5114be252 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -1,26 +1,28 @@ -use lazy_static::lazy_static; +use once_cell::sync::Lazy; use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; use std::ffi::CStr; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; -lazy_static! 
{ - pub static ref PGML_VENV: (&'static str, GucSetting>) = - ("pgml.venv", GucSetting::>::new(None)); - pub static ref PGML_HF_WHITELIST: (&'static str, GucSetting>) = ( +pub static PGML_VENV: Lazy<(&'static str, GucSetting>)> = + Lazy::new(|| ("pgml.venv", GucSetting::>::new(None))); +pub static PGML_HF_WHITELIST: Lazy<(&'static str, GucSetting>)> = Lazy::new(|| { + ( "pgml.huggingface_whitelist", GucSetting::>::new(None), - ); - pub static ref PGML_HF_TRUST_REMOTE_CODE: (&'static str, GucSetting) = - ("pgml.huggingface_trust_remote_code", GucSetting::::new(false)); - pub static ref PGML_HF_TRUST_WHITELIST: (&'static str, GucSetting>) = ( + ) +}); +pub static PGML_HF_TRUST_REMOTE_CODE: Lazy<(&'static str, GucSetting)> = + Lazy::new(|| ("pgml.huggingface_trust_remote_code", GucSetting::::new(false))); +pub static PGML_HF_TRUST_WHITELIST: Lazy<(&'static str, GucSetting>)> = Lazy::new(|| { + ( "pgml.huggingface_trust_remote_code_whitelist", GucSetting::>::new(None), - ); - pub static ref PGML_OMP_NUM_THREADS: (&'static str, GucSetting) = - ("pgml.omp_num_threads", GucSetting::::new(-1)); -} + ) +}); +pub static PGML_OMP_NUM_THREADS: Lazy<(&'static str, GucSetting)> = + Lazy::new(|| ("pgml.omp_num_threads", GucSetting::::new(-1))); pub fn initialize_server_params() { GucRegistry::define_string_guc( From 2a1853c183a10bfee1a0d3c2f3909a9cdf6cf00b Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Fri, 5 Apr 2024 21:14:08 +0800 Subject: [PATCH 4/9] Address review comments --- pgml-extension/src/config.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index 5114be252..f07a4b371 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -22,7 +22,7 @@ pub static PGML_HF_TRUST_WHITELIST: Lazy<(&'static str, GucSetting)> = - Lazy::new(|| ("pgml.omp_num_threads", GucSetting::::new(-1))); + Lazy::new(|| ("pgml.omp_num_threads", GucSetting::::new(0))); pub fn 
initialize_server_params() { GucRegistry::define_string_guc( @@ -62,7 +62,7 @@ pub fn initialize_server_params() { "Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid", "", &PGML_OMP_NUM_THREADS.1, - -1, + 0, i32::max_value(), GucContext::Backend, GucFlags::default(), From 3768179e282e1c889ec4c05530d1b1f5fa3ec95c Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Sun, 7 Apr 2024 00:08:24 +0800 Subject: [PATCH 5/9] Simplify global variables. --- pgml-extension/src/bindings/python/mod.rs | 2 +- .../src/bindings/transformers/whitelist.rs | 32 ++++++------ pgml-extension/src/config.rs | 52 ++++++++----------- pgml-extension/src/lib.rs | 2 +- 4 files changed, 40 insertions(+), 48 deletions(-) diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index 2e48052f3..f00de2c7f 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -21,7 +21,7 @@ pub fn activate_venv(venv: &str) -> Result { } pub fn activate() -> Result { - match PGML_VENV.1.get() { + match PGML_VENV.get() { Some(venv) => activate_venv(&venv.to_string_lossy()), None => Ok(false), } diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index ac7f3bf4d..ca1ac9769 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -5,7 +5,7 @@ use pgrx::{pg_schema, pg_test}; use serde_json::Value; use std::ffi::CStr; -use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_WHITELIST, PGML_HF_WHITELIST}; +use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_REMOTE_CODE_WHITELIST, PGML_HF_WHITELIST}; /// Verify that the model in the task JSON is allowed based on the huggingface whitelists. 
pub fn verify_task(task: &Value) -> Result<(), Error> { @@ -13,21 +13,21 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { Some(model) => model.to_string(), None => return Ok(()), }; - let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST.1); + let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST); let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { bail!( "model {} is not whitelisted. Consider adding to {} in postgresql.conf", task_model, - PGML_HF_WHITELIST.0 + "pgml.huggingface_whitelist" ); } let task_trust = get_trust_remote_code(task); - let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.1.get(); + let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.get(); - let trusted_models = config_csv_list(&PGML_HF_TRUST_WHITELIST.1); + let trusted_models = config_csv_list(&PGML_HF_TRUST_REMOTE_CODE_WHITELIST); let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model); @@ -36,9 +36,9 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { bail!( "model {} is not trusted to run remote code. 
Consider setting {} = 'true' or adding {} to {}", task_model, - PGML_HF_TRUST_REMOTE_CODE.0, + "pgml.huggingface_trust_remote_code", task_model, - PGML_HF_TRUST_WHITELIST.0 + "pgml.huggingface_trust_remote_code_whitelist", ); } @@ -129,7 +129,7 @@ mod tests { #[pg_test] fn test_empty_whitelist() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(PGML_HF_WHITELIST.0, "").unwrap(); + set_config("pgml.huggingface_whitelist", "").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); @@ -138,12 +138,12 @@ mod tests { #[pg_test] fn test_nonempty_whitelist() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(PGML_HF_WHITELIST.0, model).unwrap(); + set_config("pgml.huggingface_whitelist", model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); - set_config(PGML_HF_WHITELIST.0, "other_model").unwrap(); + set_config("pgml.huggingface_whitelist", "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); @@ -152,8 +152,8 @@ mod tests { #[pg_test] fn test_trusted_model() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(PGML_HF_WHITELIST.0, model).unwrap(); - set_config(PGML_HF_TRUST_WHITELIST.0, model).unwrap(); + set_config("pgml.huggingface_whitelist", model).unwrap(); + set_config("pgml.huggingface_trust_remote_code_whitelist", model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); @@ -163,7 +163,7 @@ mod tests { let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); - set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap(); + set_config("pgml.huggingface_trust_remote_code", 
"true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); @@ -176,8 +176,8 @@ mod tests { #[pg_test] fn test_untrusted_model() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(PGML_HF_WHITELIST.0, model).unwrap(); - set_config(PGML_HF_TRUST_WHITELIST.0, "other_model").unwrap(); + set_config("pgml.huggingface_whitelist", model).unwrap(); + set_config("pgml.huggingface_trust_remote_code_whitelist", "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); @@ -187,7 +187,7 @@ mod tests { let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); - set_config(PGML_HF_TRUST_REMOTE_CODE.0, "true").unwrap(); + set_config("pgml.huggingface_trust_remote_code", "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index f07a4b371..de8e21c45 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -1,67 +1,58 @@ -use once_cell::sync::Lazy; use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; use std::ffi::CStr; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; -pub static PGML_VENV: Lazy<(&'static str, GucSetting>)> = - Lazy::new(|| ("pgml.venv", GucSetting::>::new(None))); -pub static PGML_HF_WHITELIST: Lazy<(&'static str, GucSetting>)> = Lazy::new(|| { - ( - "pgml.huggingface_whitelist", - GucSetting::>::new(None), - ) -}); -pub static PGML_HF_TRUST_REMOTE_CODE: Lazy<(&'static str, GucSetting)> = - Lazy::new(|| ("pgml.huggingface_trust_remote_code", GucSetting::::new(false))); -pub static PGML_HF_TRUST_WHITELIST: Lazy<(&'static str, GucSetting>)> = Lazy::new(|| { - ( - 
"pgml.huggingface_trust_remote_code_whitelist", - GucSetting::>::new(None), - ) -}); -pub static PGML_OMP_NUM_THREADS: Lazy<(&'static str, GucSetting)> = - Lazy::new(|| ("pgml.omp_num_threads", GucSetting::::new(0))); +pub static PGML_VENV: GucSetting> = GucSetting::>::new(None); +pub static PGML_HF_WHITELIST: GucSetting> = GucSetting::>::new(None); +pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting = GucSetting::::new(false); +pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting> = + GucSetting::>::new(None); +pub static PGML_OMP_NUM_THREADS: GucSetting = GucSetting::::new(0); pub fn initialize_server_params() { GucRegistry::define_string_guc( - PGML_VENV.0, + "pgml.venv", "Python's virtual environment path", "", - &PGML_VENV.1, + &PGML_VENV, GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_string_guc( - PGML_HF_WHITELIST.0, + "pgml.huggingface_whitelist", "Models allowed to be downloaded from huggingface", "", - &PGML_HF_WHITELIST.1, + &PGML_HF_WHITELIST, GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_bool_guc( - PGML_HF_TRUST_REMOTE_CODE.0, + "pgml.huggingface_trust_remote_code", "Whether model can execute remote codes", "", - &PGML_HF_TRUST_REMOTE_CODE.1, + &PGML_HF_TRUST_REMOTE_CODE, GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_string_guc( - PGML_HF_TRUST_WHITELIST.0, + "pgml.huggingface_trust_remote_code_whitelist", "Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'", "", - &PGML_HF_TRUST_WHITELIST.1, + &PGML_HF_TRUST_REMOTE_CODE_WHITELIST, GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_int_guc( - PGML_OMP_NUM_THREADS.0, + "pgml.omp_num_threads", "Specifies the number of threads used by default of underlying OpenMP library. 
Only positive integers are valid", "", - &PGML_OMP_NUM_THREADS.1, + &PGML_OMP_NUM_THREADS, 0, i32::max_value(), GucContext::Backend, @@ -87,7 +78,8 @@ mod tests { let name = "pgml.huggingface_whitelist"; let value = "meta-llama/Llama-2-7b"; set_config(name, value).unwrap(); - assert_eq!(PGML_HF_WHITELIST.1.get().unwrap().to_string_lossy(), value); + assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value); + //assert_eq!((&*PGML_HF_WHITELIST).get().unwrap().to_str().unwrap(), value); } #[pg_test] diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index 1a45f03d3..55fa1df08 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -29,7 +29,7 @@ extern "C" { #[pg_guard] pub extern "C" fn _PG_init() { config::initialize_server_params(); - let omp_num_threads = config::PGML_OMP_NUM_THREADS.1.get(); + let omp_num_threads = config::PGML_OMP_NUM_THREADS.get(); if omp_num_threads > 0 { unsafe { omp_set_num_threads(omp_num_threads); From 6e3fe80f9753cfa8984de72c6b38c8aacfa88a32 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Sun, 7 Apr 2024 00:11:48 +0800 Subject: [PATCH 6/9] Fixup!! 
--- pgml-extension/src/config.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index de8e21c45..c2ec9fbc8 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -79,7 +79,6 @@ mod tests { let value = "meta-llama/Llama-2-7b"; set_config(name, value).unwrap(); assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value); - //assert_eq!((&*PGML_HF_WHITELIST).get().unwrap().to_str().unwrap(), value); } #[pg_test] From d05644aa43f8c4045dbe745676cca85a789f9d91 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Thu, 11 Apr 2024 07:34:06 +0800 Subject: [PATCH 7/9] Update pgml-extension/src/bindings/transformers/whitelist.rs Co-authored-by: kczimm <4733573+kczimm@users.noreply.github.com> --- pgml-extension/src/bindings/transformers/whitelist.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index ca1ac9769..164ab25f4 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -33,13 +33,7 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { let remote_code_allowed = trust_remote_code && model_is_trusted; if !remote_code_allowed && task_trust == Some(true) { - bail!( - "model {} is not trusted to run remote code. Consider setting {} = 'true' or adding {} to {}", - task_model, - "pgml.huggingface_trust_remote_code", - task_model, - "pgml.huggingface_trust_remote_code_whitelist", - ); + bail!("model {task_model} is not trusted to run remote code. 
Consider setting pgml.huggingface_trust_remote_code = 'true' or adding {task_model} to pgml.huggingface_trust_remote_code_whitelist"); } Ok(()) From 663ed25caf483a3970b8baceccef1f5bf793d448 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Thu, 11 Apr 2024 07:34:18 +0800 Subject: [PATCH 8/9] Update pgml-extension/src/bindings/transformers/whitelist.rs Co-authored-by: kczimm <4733573+kczimm@users.noreply.github.com> --- pgml-extension/src/bindings/transformers/whitelist.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index 164ab25f4..44ab2703f 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -17,11 +17,7 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { - bail!( - "model {} is not whitelisted. Consider adding to {} in postgresql.conf", - task_model, - "pgml.huggingface_whitelist" - ); + bail!("model {task_model} is not whitelisted. Consider adding to `pgml.huggingface_whitelist` in postgresql.conf"); } let task_trust = get_trust_remote_code(task); From c81c42e2a64588dd5496a27d2e36b65222584dd9 Mon Sep 17 00:00:00 2001 From: Xing Guo Date: Thu, 11 Apr 2024 07:43:21 +0800 Subject: [PATCH 9/9] Address review comments. - Set the minimum value of OMP_NUM_THREADS to 1. - Move PGML_OMP_NUM_THREADS.get() into initialize_server_params(). 
--- pgml-extension/src/config.rs | 13 +++++++++++-- pgml-extension/src/lib.rs | 10 ---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index c2ec9fbc8..424349ad0 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -9,7 +9,11 @@ pub static PGML_HF_WHITELIST: GucSetting> = GucSetting:: = GucSetting::::new(false); pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting> = GucSetting::>::new(None); -pub static PGML_OMP_NUM_THREADS: GucSetting = GucSetting::::new(0); +pub static PGML_OMP_NUM_THREADS: GucSetting = GucSetting::::new(1); + +extern "C" { + fn omp_set_num_threads(num_threads: i32); +} pub fn initialize_server_params() { GucRegistry::define_string_guc( @@ -53,11 +57,16 @@ pub fn initialize_server_params() { "Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid", "", &PGML_OMP_NUM_THREADS, - 0, + 1, i32::max_value(), GucContext::Backend, GucFlags::default(), ); + + let omp_num_threads = PGML_OMP_NUM_THREADS.get(); + unsafe { + omp_set_num_threads(omp_num_threads); + } } #[cfg(any(test, feature = "pg_test"))] diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index 55fa1df08..1eab45ae7 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -21,20 +21,10 @@ pg_module_magic!(); extension_sql_file!("../sql/schema.sql", name = "schema"); -extern "C" { - fn omp_set_num_threads(num_threads: i32); -} - #[cfg(not(feature = "use_as_lib"))] #[pg_guard] pub extern "C" fn _PG_init() { config::initialize_server_params(); - let omp_num_threads = config::PGML_OMP_NUM_THREADS.get(); - if omp_num_threads > 0 { - unsafe { - omp_set_num_threads(omp_num_threads); - } - } bindings::python::activate().expect("Error setting python venv"); orm::project::init(); }