diff --git a/src/client.rs b/src/client.rs index 6c0d06fc..7d5e9798 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1260,7 +1260,7 @@ where // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { + if self.transaction_mode && !server.in_copy_mode() { self.stats.idle(); break; @@ -1410,7 +1410,7 @@ where // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { + if self.transaction_mode && !server.in_copy_mode() { break; } } diff --git a/src/server.rs b/src/server.rs index fa68b678..55444fba 100644 --- a/src/server.rs +++ b/src/server.rs @@ -170,6 +170,9 @@ pub struct Server { /// Is there more data for the client to read. data_available: bool, + /// Is the server in copy-in or copy-out modes + in_copy_mode: bool, + /// Is the server broken? We'll remote it from the pool if so. bad: bool, @@ -677,6 +680,7 @@ impl Server { process_id, secret_key, in_transaction: false, + in_copy_mode: false, data_available: false, bad: false, cleanup_state: CleanupState::new(), @@ -828,8 +832,19 @@ impl Server { break; } + // ErrorResponse + 'E' => { + if self.in_copy_mode { + self.in_copy_mode = false; + } + } + // CommandComplete 'C' => { + if self.in_copy_mode { + self.in_copy_mode = false; + } + let mut command_tag = String::new(); match message.reader().read_to_string(&mut command_tag) { Ok(_) => { @@ -873,10 +888,14 @@ impl Server { } // CopyInResponse: copy is starting from client to server. - 'G' => break, + 'G' => { + self.in_copy_mode = true; + break; + } // CopyOutResponse: copy is starting from the server to the client. 'H' => { + self.in_copy_mode = true; self.data_available = true; break; } @@ -1030,6 +1049,10 @@ impl Server { self.in_transaction } + pub fn in_copy_mode(&self) -> bool { + self.in_copy_mode + } + /// We don't buffer all of server responses, e.g. COPY OUT produces too much data. /// The client is responsible to call `self.recv()` while this method returns true. pub fn is_data_available(&self) -> bool { @@ -1129,6 +1152,10 @@ impl Server { self.cleanup_state.reset(); } + if self.in_copy_mode() { + warn!("Server returned while still in copy-mode"); + } + Ok(()) } diff --git a/tests/ruby/copy_spec.rb b/tests/ruby/copy_spec.rb new file mode 100644 index 00000000..5d3f2c02 --- /dev/null +++ b/tests/ruby/copy_spec.rb @@ -0,0 +1,102 @@ +# frozen_string_literal: true +require_relative 'spec_helper' + + +describe "COPY Handling" do + let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) } + before do + new_configs = processes.pgcat.current_config + + # Allow connections in the pool to expire faster + new_configs["general"]["idle_timeout"] = 5 + processes.pgcat.update_config(new_configs) + # We need to kill the old process that was using the default configs + processes.pgcat.stop + processes.pgcat.start + processes.pgcat.wait_until_ready + end + + before do + processes.all_databases.first.with_connection do |conn| + conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)" + end + end + + after do + processes.all_databases.first.with_connection do |conn| + conn.async_exec "DROP TABLE copy_test_table;" + end + end + + after do + processes.all_databases.map(&:reset) + processes.pgcat.shutdown + end + + describe "COPY FROM" do + context "within transaction" do + it "finishes within alloted time" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + Timeout.timeout(3) do + conn.async_exec("BEGIN") + conn.copy_data "COPY copy_test_table FROM STDIN CSV" do + sleep 0.5 + conn.put_copy_data "some,data,to,copy\n" + conn.put_copy_data "more,data,to,copy\n" + end + conn.async_exec("COMMIT") + end + + res = conn.async_exec("SELECT * FROM copy_test_table").to_a + expect(res).to eq([ + {"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"}, + {"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"} + ]) + end + end + + context "outside transaction" do + it "finishes within alloted time" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + Timeout.timeout(3) do + conn.copy_data "COPY copy_test_table FROM STDIN CSV" do + sleep 0.5 + conn.put_copy_data "some,data,to,copy\n" + conn.put_copy_data "more,data,to,copy\n" + end + end + + res = conn.async_exec("SELECT * FROM copy_test_table").to_a + expect(res).to eq([ + {"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"}, + {"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"} + ]) + end + end + end + + describe "COPY TO" do + before do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("BEGIN") + conn.copy_data "COPY copy_test_table FROM STDIN CSV" do + conn.put_copy_data "some,data,to,copy\n" + conn.put_copy_data "more,data,to,copy\n" + end + conn.async_exec("COMMIT") + conn.close + end + + it "works" do + res = [] + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.copy_data "COPY copy_test_table TO STDOUT CSV" do + while row=conn.get_copy_data + res << row + end + end + expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"]) + end + end + +end