Skip to content

Async PG support #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .circleci/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ python3 tests/python/tests.py || exit 1

start_pgcat "info"

python3 tests/python/async_test.py

start_pgcat "info"

# Admin tests
export PGPASSWORD=admin_pass
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
Expand Down
50 changes: 40 additions & 10 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ where
}

// Grab a server from the pool.
let connection = match pool
let mut connection = match pool
.get(query_router.shard(), query_router.role(), &self.stats)
.await
{
Expand Down Expand Up @@ -975,9 +975,8 @@ where
}
};

let mut reference = connection.0;
let server = &mut *connection.0;
let address = connection.1;
let server = &mut *reference;

// Server is assigned to the client in case the client wants to
// cancel a query later.
Expand All @@ -1000,6 +999,7 @@ where

// Set application_name.
server.set_name(&self.application_name).await?;
server.switch_async(false);

let mut initial_message = Some(message);

Expand All @@ -1019,12 +1019,37 @@ where
None => {
trace!("Waiting for message inside transaction or in session mode");

match tokio::time::timeout(
idle_client_timeout_duration,
read_message(&mut self.read),
)
.await
{
let message = tokio::select! {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this idea trying to see if this helps with #303

Something else it's missing that doesn't allow to send the server messages to the client. Maybe this 2 issues are related and could be solved both with the same fix

message = tokio::time::timeout(
idle_client_timeout_duration,
read_message(&mut self.read),
) => message,

server_message = server.recv() => {
debug!("Got async message");

let server_message = match server_message {
Ok(message) => message,
Err(err) => {
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
server.mark_bad();
return Err(err);
}
};

match write_all_half(&mut self.write, &server_message).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};

continue;
}
};

match message {
Ok(Ok(message)) => message,
Ok(Err(err)) => {
// Client disconnected inside a transaction.
Expand Down Expand Up @@ -1141,9 +1166,14 @@ where

// Sync
// Frontend (client) is asking for the query result now.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to update these docs to include Flush. Something like:

                    // Sync (S)
                    // Frontend (client) is asking for the query result now.
                    // ... or ...
                    // Flush (H)
                    // Frontend (client) is asking for server to send query results now without
                    // sync.

'S' => {
'S' | 'H' => {
debug!("Sending query to server");

if code == 'H' {
server.switch_async(true);
debug!("Client requested flush, going async");
}

self.buffer.put(&message[..]);

let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
Expand Down
6 changes: 5 additions & 1 deletion src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ impl ConnectionPool {
self.databases.len()
}

/// Retrieve all bans for all servers.
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
let guard = self.banlist.read();
Expand All @@ -788,7 +789,7 @@ impl ConnectionPool {
return bans;
}

/// Get the address from the host url
/// Get the address from the host url.
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
let mut addresses = Vec::new();
for shard in 0..self.shards() {
Expand Down Expand Up @@ -827,10 +828,13 @@ impl ConnectionPool {
&self.addresses[shard][server]
}

/// Get server settings retrieved at connection setup.
pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone()
}

/// Calculate how many used connections in the pool
/// for the given server.
fn busy_connection_count(&self, address: &Address) -> u32 {
let state = self.pool_state(address.shard, address.address_index);
let idle = state.idle_connections;
Expand Down
26 changes: 24 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct Server {

/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
is_async: bool,

/// Server information the server sent us over on startup.
server_info: BytesMut,
Expand Down Expand Up @@ -450,6 +451,7 @@ impl Server {
read: BufReader::new(read),
write,
buffer: BytesMut::with_capacity(8196),
is_async: false,
server_info,
process_id,
secret_key,
Expand Down Expand Up @@ -537,6 +539,16 @@ impl Server {
}
}

/// Switch to async mode, flushing messages as soon
/// as we receive them without buffering or waiting for "ReadyForQuery".
pub fn switch_async(&mut self, on: bool) {
if on {
self.is_async = true;
} else {
self.is_async = false;
}
}

/// Receive data from the server in response to a client request.
/// This method must be called multiple times while `self.is_data_available()` is true
/// in order to receive all data the server has to offer.
Expand Down Expand Up @@ -632,7 +644,10 @@ impl Server {
// DataRow
'D' => {
// More data is available after this message, this is not the end of the reply.
self.data_available = true;
// If we're async, flush to client now.
if !self.is_async {
self.data_available = true;
}

// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
if self.buffer.len() >= 8196 {
Expand All @@ -645,7 +660,10 @@ impl Server {

// CopyOutResponse: copy is starting from the server to the client.
'H' => {
self.data_available = true;
// If we're in async mode, flush now.
if !self.is_async {
self.data_available = true;
}
break;
}

Expand All @@ -665,6 +683,10 @@ impl Server {
// Keep buffering until ReadyForQuery shows up.
_ => (),
};

if self.is_async {
break;
}
}

let bytes = self.buffer.clone();
Expand Down
60 changes: 60 additions & 0 deletions tests/python/async_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import psycopg2
import asyncio
import asyncpg

PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432"


def regular_main():
# Connect to the PostgreSQL database
conn = psycopg2.connect(
host=PGCAT_HOST,
database="sharded_db",
user="sharding_user",
password="sharding_user",
port=PGCAT_PORT,
)

# Open a cursor to perform database operations
cur = conn.cursor()

# Execute a SQL query
cur.execute("SELECT 1")

# Fetch the results
rows = cur.fetchall()

# Print the results
for row in rows:
print(row[0])

# Close the cursor and the database connection
cur.close()
conn.close()


async def main():
# Connect to the PostgreSQL database
conn = await asyncpg.connect(
host=PGCAT_HOST,
database="sharded_db",
user="sharding_user",
password="sharding_user",
port=PGCAT_PORT,
)

# Execute a SQL query
for _ in range(25):
rows = await conn.fetch("SELECT 1")

# Print the results
for row in rows:
print(row[0])

# Close the database connection
await conn.close()


regular_main()
asyncio.run(main())
11 changes: 10 additions & 1 deletion tests/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
asyncio==3.4.3
asyncpg==0.27.0
black==23.3.0
click==8.1.3
mypy-extensions==1.0.0
packaging==23.1
pathspec==0.11.1
platformdirs==3.2.0
psutil==5.9.1
psycopg2==2.9.3
psutil==5.9.1
tomli==2.0.1