Skip to content

Add Manual host banning to PgCat #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ where
let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect();

match query_parts[0].to_ascii_uppercase().as_str() {
"BAN" => {
trace!("BAN");
ban(stream, query_parts).await
}
"UNBAN" => {
trace!("UNBAN");
unban(stream, query_parts).await
}
"RELOAD" => {
trace!("RELOAD");
reload(stream, client_server_map).await
Expand All @@ -74,6 +82,10 @@ where
shutdown(stream).await
}
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
"BANS" => {
trace!("SHOW BANS");
show_bans(stream).await
}
"CONFIG" => {
trace!("SHOW CONFIG");
show_config(stream).await
Expand Down Expand Up @@ -350,6 +362,139 @@ where
custom_protocol_response_ok(stream, "SET").await
}

/// Bans a host from being used
async fn ban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let host = match tokens.get(1) {
Some(host) => host,
None => return error_response(stream, "BAN command requires a hostname to ban").await,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non blocking, can be done in a follow up: do we want to accept a duration string for how long to ban it for?

};

let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));

for (id, pool) in get_all_pools().iter() {
match pool.get_address_from_host(host) {
Some(address) => {
if !pool.is_banned(&address) {
pool.ban(&address, crate::errors::BanReason::ManualBan, -1);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host,
]));
}
}
None => {}
}
}

res.put(command_complete("BAN"));

// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
}

/// Clears a host so it can be used again.
///
/// `tokens[1]` must be the hostname to unban. For every pool that knows an
/// address with that host and currently reports it as banned, the ban is
/// removed and one `(db, user, role, host)` row is sent back to the client.
///
/// If no hostname was supplied, an error response is written to `stream`
/// instead and its result is returned.
async fn unban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
    T: tokio::io::AsyncWrite + std::marker::Unpin,
{
    let host = if let Some(host) = tokens.get(1) {
        host
    } else {
        return error_response(stream, "UNBAN command requires a hostname to unban").await;
    };

    let mut response = BytesMut::new();
    response.put(row_description(&vec![
        ("db", DataType::Text),
        ("user", DataType::Text),
        ("role", DataType::Text),
        ("host", DataType::Text),
    ]));

    for (id, pool) in get_all_pools().iter() {
        // Pools that do not know this host contribute nothing.
        let address = match pool.get_address_from_host(host) {
            Some(address) => address,
            None => continue,
        };

        // Only report hosts whose ban was actually lifted by this command.
        if !pool.is_banned(&address) {
            continue;
        }

        pool.unban(&address);

        let row = vec![
            id.db.clone(),
            id.user.clone(),
            address.role.to_string(),
            address.host,
        ];
        response.put(data_row(&row));
    }

    response.put(command_complete("UNBAN"));

    // ReadyForQuery
    response.put_u8(b'Z');
    response.put_i32(5);
    response.put_u8(b'I');

    write_all_half(stream, &response).await
}

/// Lists every ban currently recorded, across all pools.
///
/// Sends one row per banned address with the columns
/// `(db, user, role, host, reason, ban_time)`. `reason` is the `Debug`
/// rendering of the ban reason and `ban_time` is the recorded timestamp
/// rendered with `to_string()`.
async fn show_bans<T>(stream: &mut T) -> Result<(), Error>
where
    T: tokio::io::AsyncWrite + std::marker::Unpin,
{
    let mut response = BytesMut::new();
    response.put(row_description(&vec![
        ("db", DataType::Text),
        ("user", DataType::Text),
        ("role", DataType::Text),
        ("host", DataType::Text),
        ("reason", DataType::Text),
        ("ban_time", DataType::Text),
    ]));

    for (id, pool) in get_all_pools().iter() {
        // `get_bans` hands back an owned snapshot, so it can be consumed here.
        for (address, (ban_reason, ban_time)) in pool.get_bans() {
            let row = vec![
                id.db.clone(),
                id.user.clone(),
                address.role.to_string(),
                address.host,
                format!("{:?}", ban_reason),
                ban_time.to_string(),
            ];
            response.put(data_row(&row));
        }
    }

    response.put(command_complete("SHOW BANS"));

    // ReadyForQuery
    response.put_u8(b'Z');
    response.put_i32(5);
    response.put_u8(b'I');

    write_all_half(stream, &response).await
}

/// Reload the configuration file without restarting the process.
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
where
Expand Down
13 changes: 8 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};

use crate::errors::{BanReason, Error};

use std::collections::HashMap;
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
Expand All @@ -11,7 +14,7 @@ use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, Address, PoolMode};
use crate::constants::*;
use crate::errors::Error;

use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
Expand Down Expand Up @@ -1111,7 +1114,7 @@ where
match server.send(message).await {
Ok(_) => Ok(()),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageSendFailed, self.process_id);
Err(err)
}
}
Expand All @@ -1133,7 +1136,7 @@ where
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
Expand All @@ -1148,7 +1151,7 @@ where
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, self.process_id);
pool.ban(address, BanReason::StatementTimeout, self.process_id);
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
Expand All @@ -1157,7 +1160,7 @@ where
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
Expand Down
12 changes: 12 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub enum Error {
SocketError(String),
ClientBadStartup,
ProtocolSyncError(String),
BadQuery(String),
ServerError,
BadConfig,
AllServersDown,
Expand All @@ -15,3 +16,14 @@ pub enum Error {
ShuttingDown,
ParseBytesError(String),
}

/// Why a server address was placed on the ban list.
///
/// The reason is stored next to the ban timestamp and is shown by the admin
/// `SHOW BANS` command using its `Debug` representation.
#[derive(Debug, PartialEq, Clone)]
pub enum BanReason {
    /// The server failed its health check.
    FailedHealthCheck,
    /// Sending a message to the server returned an error.
    MessageSendFailed,
    /// Receiving a message from the server returned an error.
    MessageReceiveFailed,
    /// A connection to the server could not be checked out of the pool.
    FailedCheckout,
    /// A statement ran into the pool's statement timeout.
    StatementTimeout,
    /// The host was banned through the admin `BAN` command.
    // NOTE(review): a reviewer suggested renaming this variant to `AdminBan`.
    // The name is kept here because other code refers to `ManualBan`.
    ManualBan,
}
46 changes: 37 additions & 9 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::time::Instant;
use tokio::sync::Notify;

use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
use crate::errors::Error;
use crate::errors::{BanReason, Error};

use crate::server::Server;
use crate::sharding::ShardingFunction;
Expand All @@ -29,7 +29,7 @@ pub type SecretKey = i32;
pub type ServerHost = String;
pub type ServerPort = u16;

pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
pub type ClientServerMap =
Arc<Mutex<HashMap<(ProcessId, SecretKey), (ProcessId, SecretKey, ServerHost, ServerPort)>>>;
pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
Expand Down Expand Up @@ -489,7 +489,7 @@ impl ConnectionPool {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, client_process_id);
self.ban(address, BanReason::FailedCheckout, client_process_id);
self.stats
.client_checkout_error(client_process_id, address.id);
continue;
Expand Down Expand Up @@ -582,14 +582,14 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();

self.ban(&address, client_process_id);
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
return false;
}

/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, client_id: i32) {
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
// Primary can never be banned
if address.role == Role::Primary {
return;
Expand All @@ -599,12 +599,12 @@ impl ConnectionPool {
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
self.stats.client_ban_error(client_id, address.id);
guard[address.shard].insert(address.clone(), now);
guard[address.shard].insert(address.clone(), (reason, now));
}

/// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions.
pub fn _unban(&self, address: &Address) {
pub fn unban(&self, address: &Address) {
let mut guard = self.banlist.write();
guard[address.shard].remove(address);
}
Expand Down Expand Up @@ -653,9 +653,13 @@ impl ConnectionPool {
// Check if ban time is expired
let read_guard = self.banlist.read();
let exceeded_ban_time = match read_guard[address.shard].get(address) {
Some(timestamp) => {
Some((ban_reason, timestamp)) => {
let now = chrono::offset::Utc::now().naive_utc();
now.timestamp() - timestamp.timestamp() > self.settings.ban_time
if ban_reason == &BanReason::ManualBan {
now.timestamp() - timestamp.timestamp() > 60 * 60
} else {
now.timestamp() - timestamp.timestamp() > self.settings.ban_time
}
}
None => return true,
};
Expand All @@ -679,6 +683,30 @@ impl ConnectionPool {
self.databases.len()
}

/// Returns a snapshot of every ban recorded in this pool.
///
/// Walks the per-shard ban maps under the read lock and copies each entry
/// out as `(address, (reason, timestamp))`, so callers never hold the lock.
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
    let guard = self.banlist.read();
    guard
        .iter()
        .flat_map(|shard_bans| shard_bans.iter())
        .map(|(address, (reason, timestamp))| (address.clone(), (reason.clone(), *timestamp)))
        .collect()
}

/// Look up a configured server address by its host name.
///
/// Shards are scanned in order, and servers within each shard in order; the
/// first address whose `host` equals `host` is returned as an owned copy.
/// Returns `None` when no configured server uses that host.
pub fn get_address_from_host(&self, host: &str) -> Option<Address> {
    (0..self.shards())
        .flat_map(|shard| (0..self.servers(shard)).map(move |server| self.address(shard, server)))
        .find(|address| address.host == host)
        .cloned()
}

/// Get the number of servers (primary and replicas)
/// configured for a shard.
pub fn servers(&self, shard: usize) -> usize {
Expand Down
Loading