diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index ba59bef8e..f00de2c7f 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -6,11 +6,9 @@ use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::config::get_config; +use crate::config::PGML_VENV; use crate::create_pymodule; -static CONFIG_NAME: &str = "pgml.venv"; - create_pymodule!("/src/bindings/python/python.py"); pub fn activate_venv(venv: &str) -> Result { @@ -23,8 +21,8 @@ pub fn activate_venv(venv: &str) -> Result { } pub fn activate() -> Result { - match get_config(CONFIG_NAME) { - Some(venv) => activate_venv(&venv), + match PGML_VENV.get() { + Some(venv) => activate_venv(&venv.to_string_lossy()), None => Ok(false), } } diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index 0194180c0..44ab2703f 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -1,13 +1,11 @@ use anyhow::{bail, Error}; +use pgrx::GucSetting; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; use serde_json::Value; +use std::ffi::CStr; -use crate::config::get_config; - -static CONFIG_HF_WHITELIST: &str = "pgml.huggingface_whitelist"; -static CONFIG_HF_TRUST_REMOTE_CODE_BOOL: &str = "pgml.huggingface_trust_remote_code"; -static CONFIG_HF_TRUST_WHITELIST: &str = "pgml.huggingface_trust_remote_code_whitelist"; +use crate::config::{PGML_HF_TRUST_REMOTE_CODE, PGML_HF_TRUST_REMOTE_CODE_WHITELIST, PGML_HF_WHITELIST}; /// Verify that the model in the task JSON is allowed based on the huggingface whitelists. pub fn verify_task(task: &Value) -> Result<(), Error> { @@ -15,33 +13,32 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { Some(model) => model.to_string(), None => return Ok(()), }; - let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST); + let whitelisted_models = config_csv_list(&PGML_HF_WHITELIST); let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { - bail!("model {task_model} is not whitelisted. Consider adding to {CONFIG_HF_WHITELIST} in postgresql.conf"); + bail!("model {task_model} is not whitelisted. Consider adding to `pgml.huggingface_whitelist` in postgresql.conf"); } let task_trust = get_trust_remote_code(task); - let trust_remote_code = get_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL) - .map(|v| v == "true") - .unwrap_or(true); + let trust_remote_code = PGML_HF_TRUST_REMOTE_CODE.get(); - let trusted_models = config_csv_list(CONFIG_HF_TRUST_WHITELIST); + let trusted_models = config_csv_list(&PGML_HF_TRUST_REMOTE_CODE_WHITELIST); let model_is_trusted = trusted_models.is_empty() || trusted_models.contains(&task_model); let remote_code_allowed = trust_remote_code && model_is_trusted; if !remote_code_allowed && task_trust == Some(true) { - bail!("model {task_model} is not trusted to run remote code. Consider setting {CONFIG_HF_TRUST_REMOTE_CODE_BOOL} = 'true' or adding {task_model} to {CONFIG_HF_TRUST_WHITELIST}"); + bail!("model {task_model} is not trusted to run remote code. Consider setting pgml.huggingface_trust_remote_code = 'true' or adding {task_model} to pgml.huggingface_trust_remote_code_whitelist"); } Ok(()) } -fn config_csv_list(name: &str) -> Vec { - match get_config(name) { +fn config_csv_list(csv_list: &GucSetting>) -> Vec { + match csv_list.get() { Some(value) => value + .to_string_lossy() .trim_matches('"') .split(',') .filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) }) @@ -122,7 +119,7 @@ mod tests { #[pg_test] fn test_empty_whitelist() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, "").unwrap(); + set_config("pgml.huggingface_whitelist", "").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); @@ -131,12 +128,12 @@ mod tests { #[pg_test] fn test_nonempty_whitelist() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, model).unwrap(); + set_config("pgml.huggingface_whitelist", model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); - set_config(CONFIG_HF_WHITELIST, "other_model").unwrap(); + set_config("pgml.huggingface_whitelist", "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); @@ -145,8 +142,8 @@ mod tests { #[pg_test] fn test_trusted_model() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, model).unwrap(); - set_config(CONFIG_HF_TRUST_WHITELIST, model).unwrap(); + set_config("pgml.huggingface_whitelist", model).unwrap(); + set_config("pgml.huggingface_trust_remote_code_whitelist", model).unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); @@ -154,9 +151,9 @@ mod tests { let task_json = format!(json_template!(), model, true); let task: Value = serde_json::from_str(&task_json).unwrap(); - assert!(verify_task(&task).is_ok()); + assert!(verify_task(&task).is_err()); - set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap(); + set_config("pgml.huggingface_trust_remote_code", "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); @@ -169,8 +166,8 @@ mod tests { #[pg_test] fn test_untrusted_model() { let model = "Salesforce/xgen-7b-8k-inst"; - set_config(CONFIG_HF_WHITELIST, model).unwrap(); - set_config(CONFIG_HF_TRUST_WHITELIST, "other_model").unwrap(); + set_config("pgml.huggingface_whitelist", model).unwrap(); + set_config("pgml.huggingface_trust_remote_code_whitelist", "other_model").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); @@ -180,7 +177,7 @@ mod tests { let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_err()); - set_config(CONFIG_HF_TRUST_REMOTE_CODE_BOOL, "true").unwrap(); + set_config("pgml.huggingface_trust_remote_code", "true").unwrap(); let task_json = format!(json_template!(), model, false); let task: Value = serde_json::from_str(&task_json).unwrap(); assert!(verify_task(&task).is_ok()); diff --git a/pgml-extension/src/config.rs b/pgml-extension/src/config.rs index 8f9ade29a..424349ad0 100644 --- a/pgml-extension/src/config.rs +++ b/pgml-extension/src/config.rs @@ -1,16 +1,72 @@ +use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; use std::ffi::CStr; #[cfg(any(test, feature = "pg_test"))] use pgrx::{pg_schema, pg_test}; -use pgrx_pg_sys::AsPgCStr; - -pub fn get_config(name: &str) -> Option { - // SAFETY: name is not null because it is a Rust reference. - let ptr = unsafe { pgrx_pg_sys::GetConfigOption(name.as_pg_cstr(), true, false) }; - (!ptr.is_null()).then(move || { - // SAFETY: assuming pgrx_pg_sys is providing a valid, null terminated pointer. - unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string() - }) + +pub static PGML_VENV: GucSetting> = GucSetting::>::new(None); +pub static PGML_HF_WHITELIST: GucSetting> = GucSetting::>::new(None); +pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting = GucSetting::::new(false); +pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting> = + GucSetting::>::new(None); +pub static PGML_OMP_NUM_THREADS: GucSetting = GucSetting::::new(1); + +extern "C" { + fn omp_set_num_threads(num_threads: i32); +} + +pub fn initialize_server_params() { + GucRegistry::define_string_guc( + "pgml.venv", + "Python's virtual environment path", + "", + &PGML_VENV, + GucContext::Userset, + GucFlags::default(), + ); + + GucRegistry::define_string_guc( + "pgml.huggingface_whitelist", + "Models allowed to be downloaded from huggingface", + "", + &PGML_HF_WHITELIST, + GucContext::Userset, + GucFlags::default(), + ); + + GucRegistry::define_bool_guc( + "pgml.huggingface_trust_remote_code", + "Whether model can execute remote codes", + "", + &PGML_HF_TRUST_REMOTE_CODE, + GucContext::Userset, + GucFlags::default(), + ); + + GucRegistry::define_string_guc( + "pgml.huggingface_trust_remote_code_whitelist", + "Models allowed to execute remote codes when pgml.hugging_face_trust_remote_code = 'on'", + "", + &PGML_HF_TRUST_REMOTE_CODE_WHITELIST, + GucContext::Userset, + GucFlags::default(), + ); + + GucRegistry::define_int_guc( + "pgml.omp_num_threads", + "Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid", + "", + &PGML_OMP_NUM_THREADS, + 1, + i32::max_value(), + GucContext::Backend, + GucFlags::default(), + ); + + let omp_num_threads = PGML_OMP_NUM_THREADS.get(); + unsafe { + omp_set_num_threads(omp_num_threads); + } } #[cfg(any(test, feature = "pg_test"))] @@ -26,17 +82,17 @@ pub fn set_config(name: &str, value: &str) -> Result<(), pgrx::spi::Error> { mod tests { use super::*; - #[pg_test] - fn read_config_max_connections() { - let name = "max_connections"; - assert_eq!(get_config(name), Some("100".into())); - } - #[pg_test] fn read_pgml_huggingface_whitelist() { let name = "pgml.huggingface_whitelist"; let value = "meta-llama/Llama-2-7b"; set_config(name, value).unwrap(); - assert_eq!(get_config(name), Some(value.into())); + assert_eq!(PGML_HF_WHITELIST.get().unwrap().to_str().unwrap(), value); + } + + #[pg_test] + fn omp_num_threads_cannot_be_set_after_startup() { + let result = std::panic::catch_unwind(|| set_config("pgml.omp_num_threads", "1")); + assert!(result.is_err()); } } diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs index 6c2884cee..1eab45ae7 100644 --- a/pgml-extension/src/lib.rs +++ b/pgml-extension/src/lib.rs @@ -24,6 +24,7 @@ extension_sql_file!("../sql/schema.sql", name = "schema"); #[cfg(not(feature = "use_as_lib"))] #[pg_guard] pub extern "C" fn _PG_init() { + config::initialize_server_params(); bindings::python::activate().expect("Error setting python venv"); orm::project::init(); } @@ -53,7 +54,7 @@ pub mod pg_test { pub fn postgresql_conf_options() -> Vec<&'static str> { // return any postgresql.conf settings that are required for your tests - let mut options = vec!["shared_preload_libraries = 'pgml'"]; + let mut options = vec!["shared_preload_libraries = 'pgml'", "pgml.omp_num_threads = '1'"]; if let Some(venv) = option_env!("PGML_VENV") { let option = format!("pgml.venv = '{venv}'"); options.push(Box::leak(option.into_boxed_str()));