Skip to content

Commit ce7692a

Browse files
committed
Refactored out blocking read and writes
1 parent 0a9ebd5 commit ce7692a

File tree

3 files changed

+185
-19
lines changed

3 files changed

+185
-19
lines changed

driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import org.neo4j.driver.v1.exceptions.ClientException;
3636

3737
import static java.nio.ByteOrder.BIG_ENDIAN;
38+
import static org.neo4j.driver.internal.connector.socket.SocketUtils.blockingRead;
39+
import static org.neo4j.driver.internal.connector.socket.SocketUtils.blockingWrite;
3840

3941
public class SocketClient
4042
{
@@ -178,29 +180,12 @@ private SocketProtocol negotiateProtocol() throws IOException
178180
buf.flip();
179181

180182
//Do a blocking write
181-
while(buf.hasRemaining())
182-
{
183-
if (channel.write( buf ) < 0)
184-
{
185-
throw new ClientException(
186-
"Connection terminated while proposing protocol. This can happen due to network " +
187-
"instabilities, or due to restarts of the database." );
188-
}
189-
}
183+
blockingWrite(channel, buf);
190184

191185
// Read (blocking) back the servers choice
192186
buf.clear();
193187
buf.limit( 4 );
194-
195-
while(buf.hasRemaining())
196-
{
197-
if ( channel.read( buf ) < 0 )
198-
{
199-
throw new ClientException(
200-
"Connection terminated while negotiating protocol. This can happen due to network " +
201-
"instabilities, or due to restarts of the database." );
202-
}
203-
}
188+
blockingRead(channel, buf);
204189

205190
// Choose protocol, or fail
206191
buf.flip();
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package org.neo4j.driver.internal.connector.socket;
2+
3+
import java.io.IOException;
4+
import java.nio.ByteBuffer;
5+
import java.nio.channels.ByteChannel;
6+
7+
import org.neo4j.driver.internal.util.BytePrinter;
8+
import org.neo4j.driver.v1.exceptions.ClientException;
9+
10+
/**
11+
* Utility class for common operations.
12+
*/
13+
public final class SocketUtils
14+
{
15+
private SocketUtils()
16+
{
17+
throw new UnsupportedOperationException( "Do not instantiate" );
18+
}
19+
20+
public static void blockingRead(ByteChannel channel, ByteBuffer buf) throws IOException
21+
{
22+
while(buf.hasRemaining())
23+
{
24+
if (channel.read( buf ) < 0)
25+
{
26+
throw new ClientException( String.format(
27+
"Connection terminated while receiving data. This can happen due to network " +
28+
"instabilities, or due to restarts of the database. Expected %s bytes, received %s.",
29+
buf.limit(), BytePrinter.hex( buf ) ) );
30+
}
31+
}
32+
}
33+
34+
public static void blockingWrite(ByteChannel channel, ByteBuffer buf) throws IOException
35+
{
36+
while(buf.hasRemaining())
37+
{
38+
if (channel.write( buf ) < 0)
39+
{
40+
throw new ClientException( String.format(
41+
"Connection terminated while sending data. This can happen due to network " +
42+
"instabilities, or due to restarts of the database. Expected %s bytes, wrote %s.",
43+
buf.limit(), BytePrinter.hex( buf ) ) );
44+
}
45+
}
46+
}
47+
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package org.neo4j.driver.internal.connector.socket;
2+
3+
import org.junit.Rule;
4+
import org.junit.Test;
5+
import org.junit.rules.ExpectedException;
6+
7+
import java.io.IOException;
8+
import java.nio.ByteBuffer;
9+
import java.nio.channels.ByteChannel;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
13+
import org.neo4j.driver.v1.exceptions.ClientException;
14+
15+
import static org.hamcrest.CoreMatchers.equalTo;
16+
import static org.hamcrest.MatcherAssert.assertThat;
17+
import static org.mockito.Mockito.mock;
18+
import static org.mockito.Mockito.when;
19+
20+
public class SocketUtilsTest
21+
{
22+
@Rule
23+
public ExpectedException exception = ExpectedException.none();
24+
25+
@Test
26+
public void shouldReadAllBytes() throws IOException
27+
{
28+
// Given
29+
ByteBuffer buffer = ByteBuffer.allocate( 4 );
30+
ByteAtATimeChannel channel = new ByteAtATimeChannel( new byte[]{0, 1, 2, 3} );
31+
32+
// When
33+
SocketUtils.blockingRead(channel, buffer );
34+
buffer.flip();
35+
36+
// Then
37+
assertThat(buffer.get(), equalTo((byte) 0));
38+
assertThat(buffer.get(), equalTo((byte) 1));
39+
assertThat(buffer.get(), equalTo((byte) 2));
40+
assertThat(buffer.get(), equalTo((byte) 3));
41+
}
42+
43+
@Test
44+
public void shouldFailIfConnectionFailsWhileReading() throws IOException
45+
{
46+
// Given
47+
ByteBuffer buffer = ByteBuffer.allocate( 4 );
48+
ByteChannel channel = mock( ByteChannel.class );
49+
when(channel.read( buffer )).thenReturn( -1 );
50+
51+
//Expect
52+
exception.expect( ClientException.class );
53+
54+
// When
55+
SocketUtils.blockingRead(channel, buffer );
56+
}
57+
58+
@Test
59+
public void shouldWriteAllBytes() throws IOException
60+
{
61+
// Given
62+
ByteBuffer buffer = ByteBuffer.wrap( new byte[]{0, 1, 2, 3});
63+
ByteAtATimeChannel channel = new ByteAtATimeChannel( new byte[0] );
64+
65+
// When
66+
SocketUtils.blockingWrite(channel, buffer );
67+
68+
// Then
69+
assertThat(channel.writtenBytes.get(0), equalTo((byte) 0));
70+
assertThat(channel.writtenBytes.get(1), equalTo((byte) 1));
71+
assertThat(channel.writtenBytes.get(2), equalTo((byte) 2));
72+
assertThat(channel.writtenBytes.get(3), equalTo((byte) 3));
73+
}
74+
75+
@Test
76+
public void shouldFailIfConnectionFailsWhileWriting() throws IOException
77+
{
78+
// Given
79+
ByteBuffer buffer = ByteBuffer.allocate( 4 );
80+
ByteChannel channel = mock( ByteChannel.class );
81+
when(channel.write( buffer )).thenReturn( -1 );
82+
83+
//Expect
84+
exception.expect( ClientException.class );
85+
86+
// When
87+
SocketUtils.blockingWrite(channel, buffer );
88+
}
89+
90+
private static class ByteAtATimeChannel implements ByteChannel
91+
{
92+
93+
private final byte[] bytes;
94+
private int index = 0;
95+
private List<Byte> writtenBytes = new ArrayList<>( );
96+
97+
private ByteAtATimeChannel( byte[] bytes )
98+
{
99+
this.bytes = bytes;
100+
}
101+
102+
@Override
103+
public int read( ByteBuffer dst ) throws IOException
104+
{
105+
if (index >= bytes.length)
106+
{
107+
return -1;
108+
}
109+
110+
dst.put( bytes[index++]);
111+
return 1;
112+
}
113+
114+
@Override
115+
public int write( ByteBuffer src ) throws IOException
116+
{
117+
writtenBytes.add( src.get() );
118+
return 1;
119+
}
120+
121+
@Override
122+
public boolean isOpen()
123+
{
124+
return true;
125+
}
126+
127+
@Override
128+
public void close() throws IOException
129+
{
130+
131+
}
132+
}
133+
134+
}

0 commit comments

Comments
 (0)