Skip to content

Actually plugins #421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ serde_json = "1"

[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

42 changes: 39 additions & 3 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ admin_username = "admin_user"
# Password to access the virtual administrative database
admin_password = "admin_pass"

# Plugins!!
# query_router_plugins = ["pg_table_access", "intercept"]

# pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like
Expand Down Expand Up @@ -157,6 +154,45 @@ connect_timeout = 3000
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30

[plugins]

[plugins.query_logger]
enabled = false

[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]

[plugins.intercept]
enabled = true

[plugins.intercept.queries.0]

query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]

[plugins.intercept.queries.1]

query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]

# User configs are structured as pool.<pool_name>.users.<user_index>
# This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
Expand Down
63 changes: 56 additions & 7 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ pub struct General {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,

pub query_router_plugins: Option<Vec<String>>,
}

impl General {
Expand Down Expand Up @@ -404,7 +402,6 @@ impl Default for General {
auth_query_user: None,
auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours,
query_router_plugins: None,
}
}
}
Expand Down Expand Up @@ -682,6 +679,55 @@ impl Default for Shard {
}
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger {
pub enabled: bool,
}

impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
query.substitute(db, user);
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
pub result: Vec<Vec<String>>,
}

impl Query {
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
}
}
}
}

/// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Config {
Expand All @@ -700,6 +746,7 @@ pub struct Config {
pub path: String,

pub general: General,
pub plugins: Option<Plugins>,
pub pools: HashMap<String, Pool>,
}

Expand Down Expand Up @@ -737,6 +784,7 @@ impl Default for Config {
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
plugins: None,
}
}
}
Expand Down Expand Up @@ -1128,25 +1176,26 @@ pub async fn parse(path: &str) -> Result<(), Error> {

pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config();

match parse(&old_config.path).await {
Ok(()) => (),
Err(err) => {
error!("Config reload error: {:?}", err);
return Err(Error::BadConfig);
}
};

let new_config = get_config();

match CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};

if old_config.pools != new_config.pools {
info!("Pool configuration changed");
if old_config != new_config {
info!("Config changed, reloading");
ConnectionPool::from_config(client_server_map).await?;
Ok(true)
} else if old_config != new_config {
Ok(true)
} else {
Ok(false)
}
Expand Down
50 changes: 30 additions & 20 deletions src/plugins/intercept.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,41 @@ use serde_json::{json, Value};
use sqlparser::ast::Statement;
use std::collections::HashMap;

use log::debug;
use log::{debug, info};
use std::sync::Arc;

use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
pool::{PoolIdentifier, PoolMap},
query_router::QueryRouter,
};

pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, Value>>> =
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));

/// Configure the intercept plugin.
pub fn configure(pools: &PoolMap) {
/// Check if the interceptor plugin has been enabled.
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}

pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new();
for (identifier, _) in pools.iter() {
// TODO: make this configurable from a text config.
let value = fool_datagrip(&identifier.db, &identifier.user);
config.insert(identifier.clone(), value);
let mut intercept_config = intercept_config.clone();
intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), intercept_config);
}

CONFIG.store(Arc::new(config));

info!("Intercepting {} queries", intercept_config.queries.len());
}

pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
}

// TODO: use these structs for deserialization
Expand Down Expand Up @@ -78,19 +89,19 @@ impl Plugin for Intercept {
// Normalization
let q = q.to_string().to_ascii_lowercase();

for target in query_map.as_array().unwrap().iter() {
if target["query"].as_str().unwrap() == q {
debug!("Query matched: {}", q);
for (_, target) in query_map.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);

let rd = target["schema"]
.as_array()
.unwrap()
let rd = target
.schema
.iter()
.map(|row| {
let row = row.as_object().unwrap();
let name = &row[0];
let data_type = &row[1];
(
row["name"].as_str().unwrap(),
match row["data_type"].as_str().unwrap() {
name.as_str(),
match data_type.as_str() {
"text" => DataType::Text,
"anyarray" => DataType::AnyArray,
"oid" => DataType::Oid,
Expand All @@ -104,13 +115,11 @@ impl Plugin for Intercept {

result.put(row_description(&rd));

target["result"].as_array().unwrap().iter().for_each(|row| {
target.result.iter().for_each(|row| {
let row = row
.as_array()
.unwrap()
.iter()
.map(|s| {
let s = s.as_str().unwrap().to_string();
let s = s.as_str().to_string();

if s == "" {
None
Expand Down Expand Up @@ -141,6 +150,7 @@ impl Plugin for Intercept {

/// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value {
json!([
{
Expand Down
9 changes: 6 additions & 3 deletions src/plugins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//!

pub mod intercept;
pub mod query_logger;
pub mod table_access;

use crate::{errors::Error, query_router::QueryRouter};
Expand All @@ -17,6 +18,7 @@ use bytes::BytesMut;
use sqlparser::ast::Statement;

pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess;

#[derive(Clone, Debug, PartialEq)]
Expand All @@ -29,12 +31,13 @@ pub enum PluginOutput {

#[async_trait]
pub trait Plugin {
// Custom output is allowed because we want to extend this system
// to rewriting queries some day. So an output of a plugin could be
// a rewritten AST.
// Run before the query is sent to the server.
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>;

// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
}
49 changes: 49 additions & 0 deletions src/plugins/query_logger.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//! Log all queries to stdout (or somewhere else, why not).

use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement;
use std::sync::Arc;

static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));

pub struct QueryLogger;

pub fn setup() {
ENABLED.store(Arc::new(true));

info!("Logging queries to stdout");
}

pub fn disable() {
ENABLED.store(Arc::new(false));
}

pub fn enabled() -> bool {
**ENABLED.load()
}

#[async_trait]
impl Plugin for QueryLogger {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("{}", query);

Ok(PluginOutput::Allow)
}
}
Loading