diff --git a/Cargo.toml b/Cargo.toml
index 80549821..1b3c8b1e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -48,3 +48,4 @@ serde_json = "1"
 
 [target.'cfg(not(target_env = "msvc"))'.dependencies]
 jemallocator = "0.5.0"
+
diff --git a/pgcat.toml b/pgcat.toml
index dfb57822..ce366329 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -77,9 +77,6 @@ admin_username = "admin_user"
 # Password to access the virtual administrative database
 admin_password = "admin_pass"
 
-# Plugins!!
-# query_router_plugins = ["pg_table_access", "intercept"]
-
 # pool configs are structured as pool.<pool_name>
 # the pool_name is what clients use as database name when connecting.
 # For a pool named `sharded_db`, clients access that pool using connection string like
@@ -157,6 +154,45 @@ connect_timeout = 3000
 # Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
 # dns_max_ttl = 30
 
+[plugins]
+
+[plugins.query_logger]
+enabled = false
+
+[plugins.table_access]
+enabled = false
+tables = [
+  "pg_user",
+  "pg_roles",
+  "pg_database",
+]
+
+[plugins.intercept]
+enabled = true
+
+[plugins.intercept.queries.0]
+
+query = "select current_database() as a, current_schemas(false) as b"
+schema = [
+  ["a", "text"],
+  ["b", "text"],
+]
+result = [
+  ["${DATABASE}", "{public}"],
+]
+
+[plugins.intercept.queries.1]
+
+query = "select current_database(), current_schema(), current_user"
+schema = [
+  ["current_database", "text"],
+  ["current_schema", "text"],
+  ["current_user", "text"],
+]
+result = [
+  ["${DATABASE}", "public", "${USER}"],
+]
+
 # User configs are structured as pool.<pool_name>.users.<user_index>
 # This section holds the credentials for users that may connect to this cluster
 [pools.sharded_db.users.0]
diff --git a/src/config.rs b/src/config.rs
index aa98421c..cdf891ee 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -302,8 +302,6 @@ pub struct General {
     pub auth_query: Option<String>,
     pub auth_query_user: Option<String>,
     pub auth_query_password: Option<String>,
-
-    pub query_router_plugins: Option<Vec<String>>,
 }
 
 impl General {
@@ -404,7 +402,6 @@ impl Default for General {
             auth_query_user: None,
             auth_query_password: None,
             server_lifetime: 1000 * 3600 * 24, // 24 hours,
-            query_router_plugins: None,
         }
     }
 }
@@ -682,6 +679,55 @@ impl Default for Shard {
     }
 }
 
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
+pub struct Plugins {
+    pub intercept: Option<Intercept>,
+    pub table_access: Option<TableAccess>,
+    pub query_logger: Option<QueryLogger>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
+pub struct Intercept {
+    pub enabled: bool,
+    pub queries: BTreeMap<String, Query>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
+pub struct TableAccess {
+    pub enabled: bool,
+    pub tables: Vec<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
+pub struct QueryLogger {
+    pub enabled: bool,
+}
+
+impl Intercept {
+    pub fn substitute(&mut self, db: &str, user: &str) {
+        for (_, query) in self.queries.iter_mut() {
+            query.substitute(db, user);
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
+pub struct Query {
+    pub query: String,
+    pub schema: Vec<Vec<String>>,
+    pub result: Vec<Vec<String>>,
+}
+
+impl Query {
+    pub fn substitute(&mut self, db: &str, user: &str) {
+        for col in self.result.iter_mut() {
+            for i in 0..col.len() {
+                col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
+            }
+        }
+    }
+}
+
 /// Configuration wrapper.
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct Config {
@@ -700,6 +746,7 @@ pub struct Config {
     pub path: String,
 
     pub general: General,
+    pub plugins: Option<Plugins>,
     pub pools: HashMap<String, Pool>,
 }
 
@@ -737,6 +784,7 @@ impl Default for Config {
             path: Self::default_path(),
             general: General::default(),
             pools: HashMap::default(),
+            plugins: None,
         }
     }
 }
@@ -1128,6 +1176,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
 
 pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
     let old_config = get_config();
+
     match parse(&old_config.path).await {
         Ok(()) => (),
         Err(err) => {
@@ -1135,18 +1184,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
         Ok(_) => (),
         Err(err) => error!("DNS cache reinitialization error: {:?}", err),
     };
 
-    if old_config.pools != new_config.pools {
-        info!("Pool configuration changed");
+    if old_config != new_config {
+        info!("Config changed, reloading");
         ConnectionPool::from_config(client_server_map).await?;
         Ok(true)
-    } else if old_config != new_config {
-        Ok(true)
     } else {
         Ok(false)
     }
diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs
index 6e250dca..88d24d0e 100644
--- a/src/plugins/intercept.rs
+++ b/src/plugins/intercept.rs
@@ -11,10 +11,11 @@ use serde_json::{json, Value};
 use sqlparser::ast::Statement;
 use std::collections::HashMap;
 
-use log::debug;
+use log::{debug, info};
 use std::sync::Arc;
 
 use crate::{
+    config::Intercept as InterceptConfig,
     errors::Error,
     messages::{command_complete, data_row_nullable, row_description, DataType},
     plugins::{Plugin, PluginOutput},
@@ -22,19 +23,29 @@ use crate::{
     query_router::QueryRouter,
 };
 
-pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, Value>>> =
+pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
     Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
 
-/// Configure the intercept plugin.
-pub fn configure(pools: &PoolMap) {
+/// Check if the interceptor plugin has been enabled.
+pub fn enabled() -> bool {
+    !CONFIG.load().is_empty()
+}
+
+pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
     let mut config = HashMap::new();
 
     for (identifier, _) in pools.iter() {
-        // TODO: make this configurable from a text config.
-        let value = fool_datagrip(&identifier.db, &identifier.user);
-        config.insert(identifier.clone(), value);
+        let mut intercept_config = intercept_config.clone();
+        intercept_config.substitute(&identifier.db, &identifier.user);
+        config.insert(identifier.clone(), intercept_config);
     }
 
     CONFIG.store(Arc::new(config));
+
+    info!("Intercepting {} queries", intercept_config.queries.len());
+}
+
+pub fn disable() {
+    CONFIG.store(Arc::new(HashMap::new()));
 }
 
 // TODO: use these structs for deserialization
@@ -78,19 +89,19 @@ impl Plugin for Intercept {
         // Normalization
         let q = q.to_string().to_ascii_lowercase();
 
-        for target in query_map.as_array().unwrap().iter() {
-            if target["query"].as_str().unwrap() == q {
-                debug!("Query matched: {}", q);
+        for (_, target) in query_map.queries.iter() {
+            if target.query.as_str() == q {
+                debug!("Intercepting query: {}", q);
 
-                let rd = target["schema"]
-                    .as_array()
-                    .unwrap()
+                let rd = target
+                    .schema
                     .iter()
                     .map(|row| {
-                        let row = row.as_object().unwrap();
+                        let name = &row[0];
+                        let data_type = &row[1];
 
                         (
-                            row["name"].as_str().unwrap(),
-                            match row["data_type"].as_str().unwrap() {
+                            name.as_str(),
+                            match data_type.as_str() {
                                 "text" => DataType::Text,
                                 "anyarray" => DataType::AnyArray,
                                 "oid" => DataType::Oid,
@@ -104,13 +115,11 @@ impl Plugin for Intercept {
 
                 result.put(row_description(&rd));
 
-                target["result"].as_array().unwrap().iter().for_each(|row| {
+                target.result.iter().for_each(|row| {
                     let row = row
-                        .as_array()
-                        .unwrap()
                         .iter()
                         .map(|s| {
-                            let s = s.as_str().unwrap().to_string();
+                            let s = s.as_str().to_string();
 
                             if s == "" {
                                 None
@@ -141,6 +150,7 @@ impl Plugin for Intercept {
 
 /// Make IntelliJ SQL plugin believe it's talking to an actual database
 /// instead of PgCat.
+#[allow(dead_code)]
 fn fool_datagrip(database: &str, user: &str) -> Value {
     json!([
         {
diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs
index 92fa70b7..6661ece6 100644
--- a/src/plugins/mod.rs
+++ b/src/plugins/mod.rs
@@ -9,6 +9,7 @@
 //!
 pub mod intercept;
+pub mod query_logger;
 pub mod table_access;
 
 use crate::{errors::Error, query_router::QueryRouter};
@@ -17,6 +18,7 @@ use bytes::BytesMut;
 use sqlparser::ast::Statement;
 
 pub use intercept::Intercept;
+pub use query_logger::QueryLogger;
 pub use table_access::TableAccess;
 
 #[derive(Clone, Debug, PartialEq)]
@@ -29,12 +31,13 @@ pub enum PluginOutput {
 
 #[async_trait]
 pub trait Plugin {
-    // Custom output is allowed because we want to extend this system
-    // to rewriting queries some day. So an output of a plugin could be
-    // a rewritten AST.
+    // Run before the query is sent to the server.
     async fn run(
         &mut self,
         query_router: &QueryRouter,
         ast: &Vec<Statement>,
     ) -> Result<PluginOutput, Error>;
+
+    // TODO: run after the result is returned
+    // async fn callback(&mut self, query_router: &QueryRouter);
 }
diff --git a/src/plugins/query_logger.rs b/src/plugins/query_logger.rs
new file mode 100644
index 00000000..2dfda8bc
--- /dev/null
+++ b/src/plugins/query_logger.rs
@@ -0,0 +1,49 @@
+//! Log all queries to stdout (or somewhere else, why not).
+
+use crate::{
+    errors::Error,
+    plugins::{Plugin, PluginOutput},
+    query_router::QueryRouter,
+};
+use arc_swap::ArcSwap;
+use async_trait::async_trait;
+use log::info;
+use once_cell::sync::Lazy;
+use sqlparser::ast::Statement;
+use std::sync::Arc;
+
+static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
+
+pub struct QueryLogger;
+
+pub fn setup() {
+    ENABLED.store(Arc::new(true));
+
+    info!("Logging queries to stdout");
+}
+
+pub fn disable() {
+    ENABLED.store(Arc::new(false));
+}
+
+pub fn enabled() -> bool {
+    **ENABLED.load()
+}
+
+#[async_trait]
+impl Plugin for QueryLogger {
+    async fn run(
+        &mut self,
+        _query_router: &QueryRouter,
+        ast: &Vec<Statement>,
+    ) -> Result<PluginOutput, Error> {
+        let query = ast
+            .iter()
+            .map(|q| q.to_string())
+            .collect::<Vec<String>>()
+            .join("; ");
+        info!("{}", query);
+
+        Ok(PluginOutput::Allow)
+    }
+}
diff --git a/src/plugins/table_access.rs b/src/plugins/table_access.rs
index 2e23278a..4613a4fb 100644
--- a/src/plugins/table_access.rs
+++ b/src/plugins/table_access.rs
@@ -5,17 +5,37 @@ use async_trait::async_trait;
 use sqlparser::ast::{visit_relations, Statement};
 
 use crate::{
+    config::TableAccess as TableAccessConfig,
     errors::Error,
     plugins::{Plugin, PluginOutput},
     query_router::QueryRouter,
 };
 
+use log::{debug, info};
+
+use arc_swap::ArcSwap;
 use core::ops::ControlFlow;
+use once_cell::sync::Lazy;
+use std::sync::Arc;
+
+static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
+
+pub fn setup(config: &TableAccessConfig) {
+    CONFIG.store(Arc::new(config.tables.clone()));
+
+    info!("Blocking access to {} tables", config.tables.len());
+}
 
-pub struct TableAccess {
-    pub forbidden_tables: Vec<String>,
+pub fn enabled() -> bool {
+    !CONFIG.load().is_empty()
 }
 
+pub fn disable() {
+    CONFIG.store(Arc::new(vec![]));
+}
+
+pub struct TableAccess;
+
 #[async_trait]
 impl Plugin for TableAccess {
     async fn run(
@@ -24,13 +44,14 @@ impl Plugin for TableAccess {
         ast: &Vec<Statement>,
     ) -> Result<PluginOutput, Error> {
         let mut found = None;
+        let forbidden_tables = CONFIG.load();
 
         visit_relations(ast, |relation| {
             let relation = relation.to_string();
             let parts = relation.split(".").collect::<Vec<&str>>();
             let table_name = parts.last().unwrap();
 
-            if self.forbidden_tables.contains(&table_name.to_string()) {
+            if forbidden_tables.contains(&table_name.to_string()) {
                 found = Some(table_name.to_string());
                 ControlFlow::<()>::Break(())
             } else {
@@ -39,6 +60,8 @@ impl Plugin for TableAccess {
         });
 
         if let Some(found) = found {
+            debug!("Blocking access to table \"{}\"", found);
+
             Ok(PluginOutput::Deny(format!(
                 "permission for table \"{}\" denied",
                 found
diff --git a/src/pool.rs b/src/pool.rs
index b986548a..2fd380ce 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -132,8 +132,6 @@ pub struct PoolSettings {
     pub auth_query: Option<String>,
     pub auth_query_user: Option<String>,
     pub auth_query_password: Option<String>,
-
-    pub plugins: Option<Vec<String>>,
 }
 
 impl Default for PoolSettings {
@@ -158,7 +156,6 @@ impl Default for PoolSettings {
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
-            plugins: None,
         }
     }
 }
@@ -453,7 +450,6 @@ impl ConnectionPool {
                     auth_query: pool_config.auth_query.clone(),
                     auth_query_user: pool_config.auth_query_user.clone(),
                     auth_query_password: pool_config.auth_query_password.clone(),
-                    plugins: config.general.query_router_plugins.clone(),
                 },
                 validated: Arc::new(AtomicBool::new(false)),
                 paused: Arc::new(AtomicBool::new(false)),
@@ -473,10 +469,29 @@ impl ConnectionPool {
             }
         }
 
-        // Initialize plugins here if required.
-        if let Some(plugins) = config.general.query_router_plugins {
-            if plugins.contains(&String::from("intercept")) {
-                crate::plugins::intercept::configure(&new_pools);
+        if let Some(ref plugins) = config.plugins {
+            if let Some(ref intercept) = plugins.intercept {
+                if intercept.enabled {
+                    crate::plugins::intercept::setup(intercept, &new_pools);
+                } else {
+                    crate::plugins::intercept::disable();
+                }
+            }
+
+            if let Some(ref table_access) = plugins.table_access {
+                if table_access.enabled {
+                    crate::plugins::table_access::setup(table_access);
+                } else {
+                    crate::plugins::table_access::disable();
+                }
+            }
+
+            if let Some(ref query_logger) = plugins.query_logger {
+                if query_logger.enabled {
+                    crate::plugins::query_logger::setup();
+                } else {
+                    crate::plugins::query_logger::disable();
+                }
             }
         }
diff --git a/src/query_router.rs b/src/query_router.rs
index 93bcd4f2..d995b804 100644
--- a/src/query_router.rs
+++ b/src/query_router.rs
@@ -15,7 +15,10 @@ use sqlparser::parser::Parser;
 use crate::config::Role;
 use crate::errors::Error;
 use crate::messages::BytesMutReader;
-use crate::plugins::{Intercept, Plugin, PluginOutput, TableAccess};
+use crate::plugins::{
+    intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
+    TableAccess,
+};
 use crate::pool::PoolSettings;
 use crate::sharding::Sharder;
 
@@ -790,24 +793,26 @@ impl QueryRouter {
     /// Add your plugins here and execute them.
     pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
-        if let Some(plugins) = &self.pool_settings.plugins {
-            if plugins.contains(&String::from("intercept")) {
-                let mut intercept = Intercept {};
-                let result = intercept.run(&self, ast).await;
+        if query_logger::enabled() {
+            let mut query_logger = QueryLogger {};
+            let _ = query_logger.run(&self, ast).await;
+        }
 
-                if let Ok(PluginOutput::Intercept(output)) = result {
-                    return Ok(PluginOutput::Intercept(output));
-                }
+        if intercept::enabled() {
+            let mut intercept = Intercept {};
+            let result = intercept.run(&self, ast).await;
+
+            if let Ok(PluginOutput::Intercept(output)) = result {
+                return Ok(PluginOutput::Intercept(output));
             }
+        }
 
-            if plugins.contains(&String::from("pg_table_access")) {
-                let mut table_access = TableAccess {
-                    forbidden_tables: vec![String::from("pg_database"), String::from("pg_roles")],
-                };
+        if table_access::enabled() {
+            let mut table_access = TableAccess {};
+            let result = table_access.run(&self, ast).await;
 
-                if let Ok(PluginOutput::Deny(error)) = table_access.run(&self, ast).await {
-                    return Ok(PluginOutput::Deny(error));
-                }
+            if let Ok(PluginOutput::Deny(error)) = result {
+                return Ok(PluginOutput::Deny(error));
             }
         }
 
@@ -1156,7 +1161,6 @@ mod test {
             auth_query_password: None,
             auth_query_user: None,
             db: "test".to_string(),
-            plugins: None,
         };
         let mut qr = QueryRouter::new();
         assert_eq!(qr.active_role, None);
@@ -1231,7 +1235,6 @@ mod test {
             auth_query_password: None,
             auth_query_user: None,
             db: "test".to_string(),
-            plugins: None,
         };
         let mut qr = QueryRouter::new();
         qr.update_pool_settings(pool_settings.clone());
@@ -1376,13 +1379,17 @@ mod test {
 
     #[tokio::test]
     async fn test_table_access_plugin() {
-        QueryRouter::setup();
+        use crate::config::TableAccess;
+        let ta = TableAccess {
+            enabled: true,
+            tables: vec![String::from("pg_database")],
+        };
 
-        let mut qr = QueryRouter::new();
+        crate::plugins::table_access::setup(&ta);
 
-        let mut pool_settings = PoolSettings::default();
-        pool_settings.plugins = Some(vec![String::from("pg_table_access")]);
-        qr.update_pool_settings(pool_settings);
+        QueryRouter::setup();
+
+        let qr = QueryRouter::new();
 
         let query = simple_query("SELECT * FROM pg_database");
         let ast = QueryRouter::parse(&query).unwrap();
diff --git a/tests/ruby/admin_spec.rb b/tests/ruby/admin_spec.rb
index fceb95bf..ea21630f 100644
--- a/tests/ruby/admin_spec.rb
+++ b/tests/ruby/admin_spec.rb
@@ -71,15 +71,17 @@
   context "client connects but issues no queries" do
     it "only affects cl_idle stats" do
+      admin_conn = PG::connect(processes.pgcat.admin_connection_string)
+
+      before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
       connections = Array.new(20) { PG::connect(pgcat_conn_str) }
       sleep(1)
-      admin_conn = PG::connect(processes.pgcat.admin_connection_string)
       results = admin_conn.async_exec("SHOW POOLS")[0]
       %w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
         raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
       end
       expect(results["cl_idle"]).to eq("20")
-      expect(results["sv_idle"]).to eq("1")
+      expect(results["sv_idle"]).to eq(before_test)
 
       connections.map(&:close)
       sleep(1.1)
@@ -87,7 +89,7 @@
       %w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
         raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
       end
-      expect(results["sv_idle"]).to eq("1")
+      expect(results["sv_idle"]).to eq(before_test)
     end
   end
diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb
index eb0cdaa9..ad4c32a4 100644
--- a/tests/ruby/helpers/pgcat_helper.rb
+++ b/tests/ruby/helpers/pgcat_helper.rb
@@ -27,7 +27,6 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod
     primary2 = PgInstance.new(8432, user["username"], user["password"], "shard2")
 
     pgcat_cfg = pgcat.current_config
-    pgcat_cfg["general"]["query_router_plugins"] = ["intercept"]
     pgcat_cfg["pools"] = {
       "#{pool_name}" => {
        "default_role" => "any",
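
Reviewer note on the new config structs: the `${DATABASE}` and `${USER}` placeholders in `[plugins.intercept.queries.N].result` are resolved once per pool by `Intercept::substitute` / `Query::substitute` during `plugins::intercept::setup`, using the database and user from the pool identifier. A minimal sketch of a unit test for that behavior (the test name and values are invented for illustration; the `Query` struct is the one added to src/config.rs in this diff):

#[test]
fn substitutes_database_and_user_placeholders() {
    // Mirror what intercept::setup() does for each pool, using the
    // (db, user) pair taken from the pool identifier.
    let mut query = Query {
        query: "select current_database(), current_schema(), current_user".to_string(),
        schema: vec![],
        result: vec![vec![
            "${DATABASE}".to_string(),
            "public".to_string(),
            "${USER}".to_string(),
        ]],
    };

    query.substitute("sharded_db", "sharded_user");

    assert_eq!(
        query.result[0],
        vec!["sharded_db", "public", "sharded_user"]
    );
}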
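
The `[plugins.intercept.queries.0]` / `[plugins.intercept.queries.1]` tables in pgcat.toml deserialize into `queries: BTreeMap<String, Query>`, so the numeric-looking names become ordinary string keys ("0", "1") and ordering is stable. A hedged sketch of how that round-trips, assuming the same `toml` crate the config loader already uses (test name and query contents are made up):

#[test]
fn intercept_queries_use_string_keys() {
    // The table name "0" under [queries] becomes the string key "0" in the map.
    let parsed: Intercept = toml::from_str(
        r#"
enabled = true

[queries.0]
query = "select 1"
schema = [["?column?", "text"]]
result = [["1"]]
"#,
    )
    .unwrap();

    assert!(parsed.enabled);
    assert!(parsed.queries.contains_key("0"));
}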
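
Finally, the `Plugin` trait in src/plugins/mod.rs is the extension point this change builds on. A minimal additional plugin would look roughly like the sketch below; `StatementCounter` is an invented name used purely for illustration, but the trait signature and `PluginOutput::Allow` are the ones introduced in this diff:

use async_trait::async_trait;
use log::debug;
use sqlparser::ast::Statement;

use crate::{
    errors::Error,
    plugins::{Plugin, PluginOutput},
    query_router::QueryRouter,
};

// Hypothetical plugin: log how many statements a query contains, then let it through.
pub struct StatementCounter;

#[async_trait]
impl Plugin for StatementCounter {
    async fn run(
        &mut self,
        _query_router: &QueryRouter,
        ast: &Vec<Statement>,
    ) -> Result<PluginOutput, Error> {
        debug!("Query contains {} statement(s)", ast.len());

        Ok(PluginOutput::Allow)
    }
}

Like `QueryLogger`, such a plugin would still need an `enabled()`/`setup()` pair plus a call site in `QueryRouter::execute_plugins` before it actually runs.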