Skip to content

Server TLS #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.5"
fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }

[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"
10 changes: 8 additions & 2 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,15 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5

# Path to TLS Certificate file to use for TLS connections
# tls_certificate = "server.cert"
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key"
# tls_private_key = ".circleci/server.key"

# Enable/disable server TLS
server_tls = false

# Verify server certificate is completely authentic.
verify_server_certificate = false

# User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
Expand Down
13 changes: 12 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ where
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
}

Expand All @@ -565,6 +566,8 @@ where
}

Err(err) => {
wrong_password(&mut write, username).await?;

return Err(Error::ClientAuthPassthroughError(
err.to_string(),
client_identifier,
Expand All @@ -587,7 +590,15 @@ where
client_identifier
);

let fetched_hash = refetch_auth_hash(&pool).await?;
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;

return Err(err);
}
};

let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);

// Ok password changed in server an auth is possible.
Expand Down
14 changes: 14 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ pub struct General {

pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>,

#[serde(default)] // false
pub server_tls: bool,

#[serde(default)] // false
pub verify_server_certificate: bool,

pub admin_username: String,
pub admin_password: String,

Expand Down Expand Up @@ -373,6 +380,8 @@ impl Default for General {
autoreload: None,
tls_certificate: None,
tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
auth_query: None,
Expand Down Expand Up @@ -852,6 +861,11 @@ impl Config {
info!("TLS support is disabled");
}
};
info!("Server TLS enabled: {}", self.general.server_tls);
info!(
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);

for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?)
Expand Down
43 changes: 42 additions & 1 deletion src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ where

/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25);

bytes.put_i32(196608); // Protocol number
Expand Down Expand Up @@ -150,6 +153,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}

pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);

bytes.put_i32(8);
bytes.put_i32(80877103);

match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}

/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
Expand Down Expand Up @@ -505,6 +523,29 @@ where
}
}

pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}

/// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where
Expand Down
3 changes: 1 addition & 2 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,7 @@ impl ConnectionPool {
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
.await?;

pools.push(pool);
servers.push(address);
Expand Down
Loading