diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java index 83f0f7a862..efc291933b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/InternalAsyncSession.java @@ -146,7 +146,7 @@ private void executeWork(CompletableFuture resultFuture, UnmanagedTransac Throwable error = Futures.completionExceptionCause( completionError ); if ( error != null ) { - rollbackTxAfterFailedTransactionWork( tx, resultFuture, error ); + closeTxAfterFailedTransactionWork( tx, resultFuture, error ); } else { @@ -174,43 +174,33 @@ private CompletionStage safeExecuteWork(UnmanagedTransaction tx, AsyncTra } } - private void rollbackTxAfterFailedTransactionWork(UnmanagedTransaction tx, CompletableFuture resultFuture, Throwable error ) + private void closeTxAfterFailedTransactionWork( UnmanagedTransaction tx, CompletableFuture resultFuture, Throwable error ) { - if ( tx.isOpen() ) - { - tx.rollbackAsync().whenComplete( ( ignore, rollbackError ) -> { - if ( rollbackError != null ) + tx.closeAsync().whenComplete( + ( ignored, rollbackError ) -> { - error.addSuppressed( rollbackError ); - } - resultFuture.completeExceptionally( error ); - } ); - } - else - { - resultFuture.completeExceptionally( error ); - } + if ( rollbackError != null ) + { + error.addSuppressed( rollbackError ); + } + resultFuture.completeExceptionally( error ); + } ); } private void closeTxAfterSucceededTransactionWork(UnmanagedTransaction tx, CompletableFuture resultFuture, T result ) { - if ( tx.isOpen() ) - { - tx.commitAsync().whenComplete( ( ignore, completionError ) -> { - Throwable commitError = Futures.completionExceptionCause( completionError ); - if ( commitError != null ) + tx.closeAsync( true ).whenComplete( + ( ignored, completionError ) -> { - resultFuture.completeExceptionally( commitError ); - } - else - { - resultFuture.complete( result ); - } - } ); - } - else - { - resultFuture.complete( result ); - } + Throwable commitError = Futures.completionExceptionCause( completionError ); + if ( commitError != null ) + { + resultFuture.completeExceptionally( commitError ); + } + else + { + resultFuture.complete( result ); + } + } ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index a46b4d628a..fbd7a985c7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -20,9 +20,13 @@ import java.util.Arrays; import java.util.EnumSet; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.BiFunction; +import java.util.function.Function; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; @@ -37,91 +41,57 @@ import org.neo4j.driver.internal.cursor.RxResultCursor; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.internal.util.Futures; +import static org.neo4j.driver.internal.util.Futures.asCompletionException; +import static org.neo4j.driver.internal.util.Futures.combineErrors; import static org.neo4j.driver.internal.util.Futures.completedWithNull; import static org.neo4j.driver.internal.util.Futures.failedFuture; +import static org.neo4j.driver.internal.util.Futures.futureCompletingConsumer; +import static org.neo4j.driver.internal.util.LockUtil.executeWithLock; public class UnmanagedTransaction { private enum State { - /** The transaction is running with no explicit success or failure marked */ + /** + * The transaction is running with no explicit success or failure marked + */ ACTIVE, /** - * This transaction has been terminated either because of explicit {@link Session#reset()} or because of a - * fatal connection error. + * This transaction has been terminated either because of explicit {@link Session#reset()} or because of a fatal connection error. */ TERMINATED, - /** This transaction has successfully committed */ - COMMITTED, - - /** This transaction has been rolled back */ - ROLLED_BACK - } - - /** - * This is a holder so that we can have ony the state volatile in the tx without having to synchronize the whole block. - */ - private static final class StateHolder - { - private static final EnumSet OPEN_STATES = EnumSet.of( State.ACTIVE, State.TERMINATED ); - private static final StateHolder ACTIVE_HOLDER = new StateHolder( State.ACTIVE, null ); - private static final StateHolder COMMITTED_HOLDER = new StateHolder( State.COMMITTED, null ); - private static final StateHolder ROLLED_BACK_HOLDER = new StateHolder( State.ROLLED_BACK, null ); - /** - * The actual state. + * This transaction has successfully committed */ - final State value; + COMMITTED, /** - * If this holder contains a state of {@link State#TERMINATED}, this represents the cause if any. + * This transaction has been rolled back */ - final Throwable causeOfTermination; - - static StateHolder of( State value ) - { - switch ( value ) - { - case ACTIVE: - return ACTIVE_HOLDER; - case COMMITTED: - return COMMITTED_HOLDER; - case ROLLED_BACK: - return ROLLED_BACK_HOLDER; - case TERMINATED: - default: - throw new IllegalArgumentException( "Cannot provide a default state holder for state " + value ); - } - } - - static StateHolder terminatedWith( Throwable cause ) - { - return new StateHolder( State.TERMINATED, cause ); - } - - private StateHolder( State value, Throwable causeOfTermination ) - { - this.value = value; - this.causeOfTermination = causeOfTermination; - } - - boolean isOpen() - { - return OPEN_STATES.contains( this.value ); - } + ROLLED_BACK } + protected static final String CANT_COMMIT_COMMITTED_MSG = "Can't commit, transaction has been committed"; + protected static final String CANT_ROLLBACK_COMMITTED_MSG = "Can't rollback, transaction has been committed"; + protected static final String CANT_COMMIT_ROLLED_BACK_MSG = "Can't commit, transaction has been rolled back"; + protected static final String CANT_ROLLBACK_ROLLED_BACK_MSG = "Can't rollback, transaction has been rolled back"; + protected static final String CANT_COMMIT_ROLLING_BACK_MSG = "Can't commit, transaction has been requested to be rolled back"; + protected static final String CANT_ROLLBACK_COMMITTING_MSG = "Can't rollback, transaction has been requested to be committed"; + private static final EnumSet OPEN_STATES = EnumSet.of( State.ACTIVE, State.TERMINATED ); + private final Connection connection; private final BoltProtocol protocol; private final BookmarkHolder bookmarkHolder; private final ResultCursorsHolder resultCursors; private final long fetchSize; - - private volatile StateHolder state = StateHolder.of( State.ACTIVE ); + private final Lock lock = new ReentrantLock(); + private State state = State.ACTIVE; + private CompletableFuture commitFuture; + private CompletableFuture rollbackFuture; + private Throwable causeOfTermination; public UnmanagedTransaction( Connection connection, BookmarkHolder bookmarkHolder, long fetchSize ) { @@ -156,7 +126,7 @@ else if ( beginError instanceof ConnectionReadTimeoutException ) { connection.release(); } - throw Futures.asCompletionException( beginError ); + throw asCompletionException( beginError ); } return this; } ); @@ -164,50 +134,22 @@ else if ( beginError instanceof ConnectionReadTimeoutException ) public CompletionStage closeAsync() { - if ( isOpen() ) - { - return rollbackAsync(); - } - else - { - return completedWithNull(); - } + return closeAsync( false ); + } + + public CompletionStage closeAsync( boolean commit ) + { + return closeAsync( commit, true ); } public CompletionStage commitAsync() { - if ( state.value == State.COMMITTED ) - { - return failedFuture( new ClientException( "Can't commit, transaction has been committed" ) ); - } - else if ( state.value == State.ROLLED_BACK ) - { - return failedFuture( new ClientException( "Can't commit, transaction has been rolled back" ) ); - } - else - { - return resultCursors.retrieveNotConsumedError() - .thenCompose( error -> doCommitAsync( error ).handle( handleCommitOrRollback( error ) ) ) - .whenComplete( ( ignore, error ) -> handleTransactionCompletion( true, error ) ); - } + return closeAsync( true, false ); } public CompletionStage rollbackAsync() { - if ( state.value == State.COMMITTED ) - { - return failedFuture( new ClientException( "Can't rollback, transaction has been committed" ) ); - } - else if ( state.value == State.ROLLED_BACK ) - { - return failedFuture( new ClientException( "Can't rollback, transaction has been rolled back" ) ); - } - else - { - return resultCursors.retrieveNotConsumedError() - .thenCompose( error -> doRollbackAsync().handle( handleCommitOrRollback( error ) ) ) - .whenComplete( ( ignore, error ) -> handleTransactionCompletion( false, error ) ); - } + return closeAsync( false, false ); } public CompletionStage runAsync( Query query ) @@ -219,7 +161,7 @@ public CompletionStage runAsync( Query query ) return cursorStage.thenCompose( AsyncResultCursor::mapSuccessfulRunCompletionAsync ).thenApply( cursor -> cursor ); } - public CompletionStage runRx(Query query) + public CompletionStage runRx( Query query ) { ensureCanRunQueries(); CompletionStage cursorStage = @@ -230,22 +172,26 @@ public CompletionStage runRx(Query query) public boolean isOpen() { - return state.isOpen(); + return OPEN_STATES.contains( executeWithLock( lock, () -> state ) ); } public void markTerminated( Throwable cause ) { - if ( state.value == State.TERMINATED ) + executeWithLock( lock, () -> { - if ( state.causeOfTermination != null ) + if ( state == State.TERMINATED ) { - addSuppressedWhenNotCaptured( state.causeOfTermination, cause ); + if ( causeOfTermination != null ) + { + addSuppressedWhenNotCaptured( causeOfTermination, cause ); + } } - } - else - { - state = StateHolder.terminatedWith( cause ); - } + else + { + state = State.TERMINATED; + causeOfTermination = cause; + } + } ); } private void addSuppressedWhenNotCaptured( Throwable currentCause, Throwable newCause ) @@ -267,46 +213,46 @@ public Connection connection() private void ensureCanRunQueries() { - if ( state.value == State.COMMITTED ) - { - throw new ClientException( "Cannot run more queries in this transaction, it has been committed" ); - } - else if ( state.value == State.ROLLED_BACK ) + executeWithLock( lock, () -> { - throw new ClientException( "Cannot run more queries in this transaction, it has been rolled back" ); - } - else if ( state.value == State.TERMINATED ) - { - throw new ClientException( "Cannot run more queries in this transaction, " + - "it has either experienced an fatal error or was explicitly terminated", state.causeOfTermination ); - } + if ( state == State.COMMITTED ) + { + throw new ClientException( "Cannot run more queries in this transaction, it has been committed" ); + } + else if ( state == State.ROLLED_BACK ) + { + throw new ClientException( "Cannot run more queries in this transaction, it has been rolled back" ); + } + else if ( state == State.TERMINATED ) + { + throw new ClientException( "Cannot run more queries in this transaction, " + + "it has either experienced an fatal error or was explicitly terminated", causeOfTermination ); + } + } ); } private CompletionStage doCommitAsync( Throwable cursorFailure ) { - if ( state.value == State.TERMINATED ) - { - return failedFuture( new ClientException( "Transaction can't be committed. " + - "It has been rolled back either because of an error or explicit termination", - cursorFailure != state.causeOfTermination ? state.causeOfTermination : null ) ); - } - return protocol.commitTransaction( connection ).thenAccept( bookmarkHolder::setBookmark ); + ClientException exception = executeWithLock( + lock, () -> state == State.TERMINATED + ? new ClientException( "Transaction can't be committed. " + + "It has been rolled back either because of an error or explicit termination", + cursorFailure != causeOfTermination ? causeOfTermination : null ) + : null + ); + return exception != null ? failedFuture( exception ) : protocol.commitTransaction( connection ).thenAccept( bookmarkHolder::setBookmark ); } private CompletionStage doRollbackAsync() { - if ( state.value == State.TERMINATED ) - { - return completedWithNull(); - } - return protocol.rollbackTransaction( connection ); + return executeWithLock( lock, () -> state ) == State.TERMINATED ? completedWithNull() : protocol.rollbackTransaction( connection ); } private static BiFunction handleCommitOrRollback( Throwable cursorFailure ) { return ( ignore, commitOrRollbackError ) -> { - CompletionException combinedError = Futures.combineErrors( cursorFailure, commitOrRollbackError ); + CompletionException combinedError = combineErrors( cursorFailure, commitOrRollbackError ); if ( combinedError != null ) { throw combinedError; @@ -315,17 +261,19 @@ private static BiFunction handleCommitOrRollback( Throwable }; } - private void handleTransactionCompletion( boolean commitOnSuccess, Throwable throwable ) + private void handleTransactionCompletion( boolean commitAttempt, Throwable throwable ) { - if ( commitOnSuccess && throwable == null ) - { - state = StateHolder.of( State.COMMITTED ); - } - else + executeWithLock( lock, () -> { - state = StateHolder.of( State.ROLLED_BACK ); - } - + if ( commitAttempt && throwable == null ) + { + state = State.COMMITTED; + } + else + { + state = State.ROLLED_BACK; + } + } ); if ( throwable instanceof AuthorizationExpiredException ) { connection.terminateAndRelease( AuthorizationExpiredException.DESCRIPTION ); @@ -339,4 +287,81 @@ else if ( throwable instanceof ConnectionReadTimeoutException ) connection.release(); // release in background } } + + private CompletionStage closeAsync( boolean commit, boolean completeWithNullIfNotOpen ) + { + CompletionStage stage = executeWithLock( lock, () -> + { + CompletionStage resultStage = null; + if ( completeWithNullIfNotOpen && !isOpen() ) + { + resultStage = completedWithNull(); + } + else if ( state == State.COMMITTED ) + { + resultStage = failedFuture( new ClientException( commit ? CANT_COMMIT_COMMITTED_MSG : CANT_ROLLBACK_COMMITTED_MSG ) ); + } + else if ( state == State.ROLLED_BACK ) + { + resultStage = failedFuture( new ClientException( commit ? CANT_COMMIT_ROLLED_BACK_MSG : CANT_ROLLBACK_ROLLED_BACK_MSG ) ); + } + else + { + if ( commit ) + { + if ( rollbackFuture != null ) + { + resultStage = failedFuture( new ClientException( CANT_COMMIT_ROLLING_BACK_MSG ) ); + } + else if ( commitFuture != null ) + { + resultStage = commitFuture; + } + else + { + commitFuture = new CompletableFuture<>(); + } + } + else + { + if ( commitFuture != null ) + { + resultStage = failedFuture( new ClientException( CANT_ROLLBACK_COMMITTING_MSG ) ); + } + else if ( rollbackFuture != null ) + { + resultStage = rollbackFuture; + } + else + { + rollbackFuture = new CompletableFuture<>(); + } + } + } + return resultStage; + } ); + + if ( stage == null ) + { + CompletableFuture targetFuture; + Function> targetAction; + if ( commit ) + { + targetFuture = commitFuture; + targetAction = throwable -> doCommitAsync( throwable ).handle( handleCommitOrRollback( throwable ) ); + } + else + { + targetFuture = rollbackFuture; + targetAction = throwable -> doRollbackAsync().handle( handleCommitOrRollback( throwable ) ); + } + resultCursors.retrieveNotConsumedError() + .thenCompose( targetAction ) + .whenComplete( ( ignored, throwable ) -> handleTransactionCompletion( commit, throwable ) ) + .whenComplete( futureCompletingConsumer( targetFuture ) ); + stage = targetFuture; + } + + return stage; + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java index 110892c026..222b64562d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java @@ -130,7 +130,7 @@ public Publisher writeTransaction( RxTransactionWork Publisher runTransaction( AccessMode mode, RxTransactionWork> work, TransactionConfig config ) { Flux repeatableWork = Flux.usingWhen( beginTransaction( mode, config ), work::execute, - InternalRxTransaction::commitIfOpen, ( tx, error ) -> tx.close(), null ); + tx -> tx.close( true ), ( tx, error ) -> tx.close(), InternalRxTransaction::close ); return session.retryLogic().retryRx( repeatableWork ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java index c1a9267336..b4212ae963 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxTransaction.java @@ -30,7 +30,6 @@ import org.neo4j.driver.reactive.RxTransaction; import static org.neo4j.driver.internal.reactive.RxUtils.createEmptyPublisher; -import static org.neo4j.driver.internal.util.Futures.completedWithNull; public class InternalRxTransaction extends AbstractRxQueryRunner implements RxTransaction { @@ -77,13 +76,13 @@ public Publisher rollback() return createEmptyPublisher( tx::rollbackAsync ); } - Publisher commitIfOpen() + Publisher close() { - return createEmptyPublisher( () -> tx.isOpen() ? tx.commitAsync() : completedWithNull() ); + return close( false ); } - Publisher close() + Publisher close( boolean commit ) { - return createEmptyPublisher( tx::closeAsync ); + return createEmptyPublisher( () -> tx.closeAsync( commit ) ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java index 24ed13c879..56b714df1c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java @@ -256,6 +256,21 @@ public static CompletableFuture onErrorContinue( CompletableFuture fut } ); } + public static BiConsumer futureCompletingConsumer( CompletableFuture future ) + { + return ( value, throwable ) -> + { + if ( throwable != null ) + { + future.completeExceptionally( throwable ); + } + else + { + future.complete( value ); + } + }; + } + private static class CompletionResult { T value; diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index b8909e176f..f639565a66 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -19,9 +19,17 @@ package org.neo4j.driver.internal.async; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.InOrder; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Stream; import org.neo4j.driver.Bookmark; import org.neo4j.driver.Query; @@ -32,6 +40,7 @@ import org.neo4j.driver.internal.DefaultBookmarkHolder; import org.neo4j.driver.internal.FailableCursor; import org.neo4j.driver.internal.InternalBookmark; +import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -40,16 +49,21 @@ import static java.util.concurrent.CompletableFuture.completedFuture; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.neo4j.driver.internal.handlers.pulln.FetchSizeUtil.UNLIMITED_FETCH_SIZE; import static org.neo4j.driver.util.TestUtil.assertNoCircularReferences; @@ -311,6 +325,133 @@ void shouldReleaseConnectionOnConnectionReadTimeoutExceptionFailure() verify( connection, never() ).release(); } + private static Stream similarTransactionCompletingActionArgs() + { + return Stream.of( + Arguments.of( true, "commit", "commit" ), + + Arguments.of( false, "rollback", "rollback" ), + Arguments.of( false, "rollback", "close" ), + + Arguments.of( false, "close", "rollback" ), + Arguments.of( false, "close", "close" ) + ); + } + + @ParameterizedTest + @MethodSource( "similarTransactionCompletingActionArgs" ) + void shouldReturnExistingStageOnSimilarCompletingAction( boolean protocolCommit, String initialAction, String similarAction ) + { + Connection connection = mock( Connection.class ); + BoltProtocol protocol = mock( BoltProtocol.class ); + given( connection.protocol() ).willReturn( protocol ); + given( protocolCommit ? protocol.commitTransaction( connection ) : protocol.rollbackTransaction( connection ) ).willReturn( new CompletableFuture<>() ); + UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); + + CompletionStage initialStage = mapTransactionAction( initialAction, tx ).get(); + CompletionStage similarStage = mapTransactionAction( similarAction, tx ).get(); + + assertSame( initialStage, similarStage ); + if ( protocolCommit ) + { + then( protocol ).should( times( 1 ) ).commitTransaction( connection ); + } + else + { + then( protocol ).should( times( 1 ) ).rollbackTransaction( connection ); + } + } + + private static Stream conflictingTransactionCompletingActionArgs() + { + return Stream.of( + Arguments.of( true, true, "commit", "commit", UnmanagedTransaction.CANT_COMMIT_COMMITTED_MSG ), + Arguments.of( true, true, "commit", "rollback", UnmanagedTransaction.CANT_ROLLBACK_COMMITTED_MSG ), + Arguments.of( true, false, "commit", "rollback", UnmanagedTransaction.CANT_ROLLBACK_COMMITTING_MSG ), + Arguments.of( true, false, "commit", "close", UnmanagedTransaction.CANT_ROLLBACK_COMMITTING_MSG ), + + Arguments.of( false, true, "rollback", "rollback", UnmanagedTransaction.CANT_ROLLBACK_ROLLED_BACK_MSG ), + Arguments.of( false, true, "rollback", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLED_BACK_MSG ), + Arguments.of( false, false, "rollback", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLING_BACK_MSG ), + + Arguments.of( false, true, "close", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLED_BACK_MSG ), + Arguments.of( false, true, "close", "rollback", UnmanagedTransaction.CANT_ROLLBACK_ROLLED_BACK_MSG ), + Arguments.of( false, false, "close", "commit", UnmanagedTransaction.CANT_COMMIT_ROLLING_BACK_MSG ) + ); + } + + @ParameterizedTest + @MethodSource( "conflictingTransactionCompletingActionArgs" ) + void shouldReturnFailingStageOnConflictingCompletingAction( boolean protocolCommit, boolean protocolActionCompleted, String initialAction, + String conflictingAction, String expectedErrorMsg ) + { + Connection connection = mock( Connection.class ); + BoltProtocol protocol = mock( BoltProtocol.class ); + given( connection.protocol() ).willReturn( protocol ); + given( protocolCommit ? protocol.commitTransaction( connection ) : protocol.rollbackTransaction( connection ) ) + .willReturn( protocolActionCompleted ? completedFuture( null ) : new CompletableFuture<>() ); + UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); + + CompletionStage originalActionStage = mapTransactionAction( initialAction, tx ).get(); + CompletionStage conflictingActionStage = mapTransactionAction( conflictingAction, tx ).get(); + + assertNotNull( originalActionStage ); + if ( protocolCommit ) + { + then( protocol ).should( times( 1 ) ).commitTransaction( connection ); + } + else + { + then( protocol ).should( times( 1 ) ).rollbackTransaction( connection ); + } + assertTrue( conflictingActionStage.toCompletableFuture().isCompletedExceptionally() ); + Throwable throwable = assertThrows( ExecutionException.class, () -> conflictingActionStage.toCompletableFuture().get() ).getCause(); + assertTrue( throwable instanceof ClientException ); + assertEquals( expectedErrorMsg, throwable.getMessage() ); + } + + private static Stream closingNotActionTransactionArgs() + { + return Stream.of( + Arguments.of( true, 1, "commit", null ), + Arguments.of( false, 1, "rollback", null ), + Arguments.of( false, 0, "terminate", null ), + Arguments.of( true, 1, "commit", true ), + Arguments.of( false, 1, "rollback", true ), + Arguments.of( true, 1, "commit", false ), + Arguments.of( false, 1, "rollback", false ), + Arguments.of( false, 0, "terminate", false ) + ); + } + + @ParameterizedTest + @MethodSource( "closingNotActionTransactionArgs" ) + void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommittingAborted( + boolean protocolCommit, int expectedProtocolInvocations, String originalAction, Boolean commitOnClose ) + { + Connection connection = mock( Connection.class ); + BoltProtocol protocol = mock( BoltProtocol.class ); + given( connection.protocol() ).willReturn( protocol ); + given( protocolCommit ? protocol.commitTransaction( connection ) : protocol.rollbackTransaction( connection ) ) + .willReturn( completedFuture( null ) ); + UnmanagedTransaction tx = new UnmanagedTransaction( connection, new DefaultBookmarkHolder(), UNLIMITED_FETCH_SIZE ); + + CompletionStage originalActionStage = mapTransactionAction( originalAction, tx ).get(); + CompletionStage closeStage = commitOnClose != null ? tx.closeAsync( commitOnClose ) : tx.closeAsync(); + + assertTrue( originalActionStage.toCompletableFuture().isDone() ); + assertFalse( originalActionStage.toCompletableFuture().isCompletedExceptionally() ); + if ( protocolCommit ) + { + then( protocol ).should( times( expectedProtocolInvocations ) ).commitTransaction( connection ); + } + else + { + then( protocol ).should( times( expectedProtocolInvocations ) ).rollbackTransaction( connection ); + } + assertNull( closeStage.toCompletableFuture().join() ); + } + private static UnmanagedTransaction beginTx( Connection connection ) { return beginTx( connection, InternalBookmark.empty() ); @@ -346,4 +487,34 @@ private ResultCursorsHolder mockResultCursorWith( ClientException clientExceptio resultCursorsHolder.add( completedFuture( cursor ) ); return resultCursorsHolder; } + + private Supplier> mapTransactionAction( String actionName, UnmanagedTransaction tx ) + { + Supplier> action; + if ( "commit".equals( actionName ) ) + { + action = tx::commitAsync; + } + else if ( "rollback".equals( actionName ) ) + { + action = tx::rollbackAsync; + } + else if ( "terminate".equals( actionName ) ) + { + action = () -> + { + tx.markTerminated( mock( Throwable.class ) ); + return completedFuture( null ); + }; + } + else if ( "close".equals( actionName ) ) + { + action = tx::closeAsync; + } + else + { + throw new RuntimeException( String.format( "Unknown completing action type '%s'", actionName ) ); + } + return action; + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java index 3f320231ab..2ca1ea21fe 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java @@ -199,9 +199,7 @@ void shouldDelegateRunTx( Function> runTx ) throws T // Given NetworkSession session = mock( NetworkSession.class ); UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( completedWithNull() ); - when( tx.rollbackAsync() ).thenReturn( completedWithNull() ); + when( tx.closeAsync( true ) ).thenReturn( completedWithNull() ); when( session.beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ) ).thenReturn( completedFuture( tx ) ); when( session.retryLogic() ).thenReturn( new FixedRetryLogic( 1 ) ); @@ -213,7 +211,7 @@ void shouldDelegateRunTx( Function> runTx ) throws T // Then verify( session ).beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ); - verify( tx ).commitAsync(); + verify( tx ).closeAsync( true ); } @Test @@ -223,25 +221,24 @@ void shouldRetryOnError() throws Throwable int retryCount = 2; NetworkSession session = mock( NetworkSession.class ); UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( completedWithNull() ); - when( tx.rollbackAsync() ).thenReturn( completedWithNull() ); + when( tx.closeAsync( false ) ).thenReturn( completedWithNull() ); when( session.beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ) ).thenReturn( completedFuture( tx ) ); when( session.retryLogic() ).thenReturn( new FixedRetryLogic( retryCount ) ); InternalRxSession rxSession = new InternalRxSession( session ); // When - Publisher strings = rxSession.readTransaction( t -> - Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ) ); + Publisher strings = rxSession.readTransaction( + t -> + Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ) ); StepVerifier.create( Flux.from( strings ) ) - // we lost the "a"s too as the user only see the last failure - .expectError( RuntimeException.class ) - .verify(); + // we lost the "a"s too as the user only see the last failure + .expectError( RuntimeException.class ) + .verify(); // Then verify( session, times( retryCount + 1 ) ).beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ); - verify( tx, times( retryCount + 1 ) ).closeAsync(); + verify( tx, times( retryCount + 1 ) ).closeAsync( false ); } @Test @@ -251,9 +248,8 @@ void shouldObtainResultIfRetrySucceed() throws Throwable int retryCount = 2; NetworkSession session = mock( NetworkSession.class ); UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( completedWithNull() ); - when( tx.rollbackAsync() ).thenReturn( completedWithNull() ); + when( tx.closeAsync( false ) ).thenReturn( completedWithNull() ); + when( tx.closeAsync( true ) ).thenReturn( completedWithNull() ); when( session.beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ) ).thenReturn( completedFuture( tx ) ); when( session.retryLogic() ).thenReturn( new FixedRetryLogic( retryCount ) ); @@ -261,23 +257,25 @@ void shouldObtainResultIfRetrySucceed() throws Throwable // When AtomicInteger count = new AtomicInteger(); - Publisher strings = rxSession.readTransaction( t -> { - // we fail for the first few retries, and then success on the last run. - if ( count.getAndIncrement() == retryCount ) - { - return Flux.just( "a" ); - } - else - { - return Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ); - } - } ); + Publisher strings = rxSession.readTransaction( + t -> + { + // we fail for the first few retries, and then success on the last run. + if ( count.getAndIncrement() == retryCount ) + { + return Flux.just( "a" ); + } + else + { + return Flux.just( "a" ).then( Mono.error( new RuntimeException( "Errored" ) ) ); + } + } ); StepVerifier.create( Flux.from( strings ) ).expectNext( "a" ).verifyComplete(); // Then verify( session, times( retryCount + 1 ) ).beginTransactionAsync( any( AccessMode.class ), any( TransactionConfig.class ) ); - verify( tx, times( retryCount ) ).closeAsync(); - verify( tx ).commitAsync(); + verify( tx, times( retryCount ) ).closeAsync( false ); + verify( tx ).closeAsync( true ); } @Test diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java index 1accde96db..5a9f0bb4b6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxTransactionTest.java @@ -48,7 +48,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.Values.parameters; @@ -140,43 +139,28 @@ void shouldMarkTxIfFailedToRun( Function runReturnOne ) } @Test - void shouldCommitWhenOpen() + void shouldDelegateConditionalClose() { UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( true ); - when( tx.commitAsync() ).thenReturn( Futures.completedWithNull() ); - - InternalRxTransaction rxTx = new InternalRxTransaction( tx ); - Publisher publisher = rxTx.commitIfOpen(); - StepVerifier.create( publisher ).verifyComplete(); - - verify( tx ).commitAsync(); - } - - @Test - void shouldNotCommitWhenNotOpen() - { - UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.isOpen() ).thenReturn( false ); - when( tx.commitAsync() ).thenReturn( Futures.completedWithNull() ); + when( tx.closeAsync( true ) ).thenReturn( Futures.completedWithNull() ); InternalRxTransaction rxTx = new InternalRxTransaction( tx ); - Publisher publisher = rxTx.commitIfOpen(); + Publisher publisher = rxTx.close( true ); StepVerifier.create( publisher ).verifyComplete(); - verify( tx, never() ).commitAsync(); + verify( tx ).closeAsync( true ); } @Test void shouldDelegateClose() { UnmanagedTransaction tx = mock( UnmanagedTransaction.class ); - when( tx.closeAsync() ).thenReturn( Futures.completedWithNull() ); + when( tx.closeAsync( false ) ).thenReturn( Futures.completedWithNull() ); InternalRxTransaction rxTx = new InternalRxTransaction( tx ); Publisher publisher = rxTx.close(); StepVerifier.create( publisher ).verifyComplete(); - verify( tx ).closeAsync(); + verify( tx ).closeAsync( false ); } }