Skip to content

Commit c3b64dc

Browse files
artembilangaryrussell
authored andcommitted
INT-4366: Fix MulticastSendingMessageHandler (#2329)
* INT-4366: Fix MulticastSendingMessageHandler JIRA: https://jira.spring.io/browse/INT-4366 Fix race condition in the `MulticastSendingMessageHandler` around `multicastSocket` and super `socket` properties. * Synchronize around `this` and check for the `multicastSocket == null`. This let the `MulticastSendingMessageHandler` to fully configure and prepare the socket for use. * Remove `socket.setInterface(whichNic)` since it is populated by the `InetSocketAddress` ctor before **Cherry-pick to 4.3.x** * Fix thread leaks in TCP/IP tests
1 parent 8aa91d1 commit c3b64dc

13 files changed

+261
-235
lines changed

spring-integration-ip/src/main/java/org/springframework/integration/ip/udp/MulticastSendingMessageHandler.java

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2001-2016 the original author or authors.
2+
* Copyright 2001-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,6 +38,8 @@
3838
* determine success.
3939
*
4040
* @author Gary Russell
41+
* @author Artem Bilan
42+
*
4143
* @since 2.0
4244
*/
4345
public class MulticastSendingMessageHandler extends UnicastSendingMessageHandler {
@@ -126,49 +128,45 @@ public MulticastSendingMessageHandler(String destinationExpression) {
126128

127129
@Override
128130
protected DatagramSocket getSocket() throws IOException {
129-
if (this.getTheSocket() == null) {
131+
if (this.multicastSocket == null) {
130132
synchronized (this) {
131-
createSocket();
133+
if (this.multicastSocket == null) {
134+
createSocket();
135+
}
132136
}
133137
}
134-
return this.getTheSocket();
138+
return getTheSocket();
135139
}
136140

137141
private void createSocket() throws IOException {
138-
if (this.getTheSocket() == null) {
139-
MulticastSocket socket;
140-
if (this.isAcknowledge()) {
141-
int ackPort = this.getAckPort();
142-
if (this.localAddress == null) {
143-
socket = ackPort == 0 ? new MulticastSocket() : new MulticastSocket(ackPort);
144-
}
145-
else {
146-
InetAddress whichNic = InetAddress.getByName(this.localAddress);
147-
socket = new MulticastSocket(new InetSocketAddress(whichNic, ackPort));
148-
}
149-
if (getSoReceiveBufferSize() > 0) {
150-
socket.setReceiveBufferSize(this.getSoReceiveBufferSize());
151-
}
152-
if (logger.isDebugEnabled()) {
153-
logger.debug("Listening for acks on port: " + socket.getLocalPort());
154-
}
155-
setSocket(socket);
156-
updateAckAddress();
142+
MulticastSocket socket;
143+
if (isAcknowledge()) {
144+
int ackPort = getAckPort();
145+
if (this.localAddress == null) {
146+
socket = ackPort == 0 ? new MulticastSocket() : new MulticastSocket(ackPort);
157147
}
158148
else {
159-
socket = new MulticastSocket();
160-
setSocket(socket);
149+
InetAddress whichNic = InetAddress.getByName(this.localAddress);
150+
socket = new MulticastSocket(new InetSocketAddress(whichNic, ackPort));
161151
}
162-
if (this.timeToLive >= 0) {
163-
socket.setTimeToLive(this.timeToLive);
152+
if (getSoReceiveBufferSize() > 0) {
153+
socket.setReceiveBufferSize(getSoReceiveBufferSize());
164154
}
165-
setSocketAttributes(socket);
166-
if (this.localAddress != null) {
167-
InetAddress whichNic = InetAddress.getByName(this.localAddress);
168-
socket.setInterface(whichNic);
155+
if (logger.isDebugEnabled()) {
156+
logger.debug("Listening for acks on port: " + socket.getLocalPort());
169157
}
170-
this.multicastSocket = socket;
158+
setSocket(socket);
159+
updateAckAddress();
160+
}
161+
else {
162+
socket = new MulticastSocket();
163+
setSocket(socket);
164+
}
165+
if (this.timeToLive >= 0) {
166+
socket.setTimeToLive(this.timeToLive);
171167
}
168+
setSocketAttributes(socket);
169+
this.multicastSocket = socket;
172170
}
173171

174172

@@ -178,7 +176,7 @@ private void createSocket() throws IOException {
178176
* @param minAcksForSuccess The minimum number of acks that will represent success.
179177
*/
180178
public void setMinAcksForSuccess(int minAcksForSuccess) {
181-
this.setAckCounter(minAcksForSuccess);
179+
setAckCounter(minAcksForSuccess);
182180
}
183181

184182
/**

spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpInboundGatewayTests.java

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2016 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,7 +28,6 @@
2828
import java.util.HashSet;
2929
import java.util.Set;
3030
import java.util.concurrent.CountDownLatch;
31-
import java.util.concurrent.Executors;
3231
import java.util.concurrent.TimeUnit;
3332
import java.util.concurrent.atomic.AtomicBoolean;
3433
import java.util.concurrent.atomic.AtomicInteger;
@@ -39,6 +38,7 @@
3938
import org.junit.Test;
4039

4140
import org.springframework.beans.factory.BeanFactory;
41+
import org.springframework.core.task.SimpleAsyncTaskExecutor;
4242
import org.springframework.integration.channel.DirectChannel;
4343
import org.springframework.integration.channel.QueueChannel;
4444
import org.springframework.integration.handler.ServiceActivatingHandler;
@@ -56,6 +56,8 @@
5656

5757
/**
5858
* @author Gary Russell
59+
* @author Artem Bilan
60+
*
5961
* @since 2.0
6062
*/
6163
public class TcpInboundGatewayTests {
@@ -119,30 +121,31 @@ public void testNetClientMode() throws Exception {
119121
final CountDownLatch latch2 = new CountDownLatch(1);
120122
final CountDownLatch latch3 = new CountDownLatch(1);
121123
final AtomicBoolean done = new AtomicBoolean();
122-
Executors.newSingleThreadExecutor().execute(() -> {
123-
try {
124-
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
125-
port.set(server.getLocalPort());
126-
latch1.countDown();
127-
Socket socket = server.accept();
128-
socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
129-
byte[] bytes = new byte[12];
130-
readFully(socket.getInputStream(), bytes);
131-
assertEquals("Echo:Test1\r\n", new String(bytes));
132-
readFully(socket.getInputStream(), bytes);
133-
assertEquals("Echo:Test2\r\n", new String(bytes));
134-
latch2.await();
135-
socket.close();
136-
server.close();
137-
done.set(true);
138-
latch3.countDown();
139-
}
140-
catch (Exception e) {
141-
if (!done.get()) {
142-
e.printStackTrace();
143-
}
144-
}
145-
});
124+
new SimpleAsyncTaskExecutor()
125+
.execute(() -> {
126+
try {
127+
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
128+
port.set(server.getLocalPort());
129+
latch1.countDown();
130+
Socket socket = server.accept();
131+
socket.getOutputStream().write("Test1\r\nTest2\r\n".getBytes());
132+
byte[] bytes = new byte[12];
133+
readFully(socket.getInputStream(), bytes);
134+
assertEquals("Echo:Test1\r\n", new String(bytes));
135+
readFully(socket.getInputStream(), bytes);
136+
assertEquals("Echo:Test2\r\n", new String(bytes));
137+
latch2.await();
138+
socket.close();
139+
server.close();
140+
done.set(true);
141+
latch3.countDown();
142+
}
143+
catch (Exception e) {
144+
if (!done.get()) {
145+
e.printStackTrace();
146+
}
147+
}
148+
});
146149
assertTrue(latch1.await(10, TimeUnit.SECONDS));
147150
AbstractClientConnectionFactory ccf = new TcpNetClientConnectionFactory("localhost", port.get());
148151
ccf.setSingleUse(false);

spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import java.util.Set;
4444
import java.util.concurrent.CountDownLatch;
4545
import java.util.concurrent.ExecutionException;
46-
import java.util.concurrent.Executors;
4746
import java.util.concurrent.Future;
4847
import java.util.concurrent.TimeUnit;
4948
import java.util.concurrent.atomic.AtomicBoolean;
@@ -61,6 +60,8 @@
6160
import org.springframework.beans.factory.BeanFactory;
6261
import org.springframework.core.serializer.DefaultDeserializer;
6362
import org.springframework.core.serializer.DefaultSerializer;
63+
import org.springframework.core.task.AsyncTaskExecutor;
64+
import org.springframework.core.task.SimpleAsyncTaskExecutor;
6465
import org.springframework.expression.EvaluationContext;
6566
import org.springframework.expression.Expression;
6667
import org.springframework.expression.spel.standard.SpelExpressionParser;
@@ -90,6 +91,8 @@ public class TcpOutboundGatewayTests {
9091

9192
private static final Log logger = LogFactory.getLog(TcpOutboundGatewayTests.class);
9293

94+
private AsyncTaskExecutor executor = new SimpleAsyncTaskExecutor();
95+
9396
@ClassRule
9497
public static LongRunningIntegrationTest longTests = new LongRunningIntegrationTest();
9598

@@ -101,13 +104,13 @@ public class TcpOutboundGatewayTests {
101104
public void testGoodNetSingle() throws Exception {
102105
final CountDownLatch latch = new CountDownLatch(1);
103106
final AtomicBoolean done = new AtomicBoolean();
104-
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
105-
Executors.newSingleThreadExecutor().execute(() -> {
107+
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
108+
this.executor.execute(() -> {
106109
try {
107110
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 100);
108111
serverSocket.set(server);
109112
latch.countDown();
110-
List<Socket> sockets = new ArrayList<Socket>();
113+
List<Socket> sockets = new ArrayList<>();
111114
int i = 0;
112115
while (true) {
113116
Socket socket = server.accept();
@@ -165,8 +168,8 @@ public void testGoodNetSingle() throws Exception {
165168
public void testGoodNetMultiplex() throws Exception {
166169
final CountDownLatch latch = new CountDownLatch(1);
167170
final AtomicBoolean done = new AtomicBoolean();
168-
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
169-
Executors.newSingleThreadExecutor().execute(() -> {
171+
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
172+
this.executor.execute(() -> {
170173
try {
171174
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0, 10);
172175
serverSocket.set(server);
@@ -220,8 +223,8 @@ public void testGoodNetMultiplex() throws Exception {
220223
public void testGoodNetTimeout() throws Exception {
221224
final CountDownLatch latch = new CountDownLatch(1);
222225
final AtomicBoolean done = new AtomicBoolean();
223-
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
224-
Executors.newSingleThreadExecutor().execute(() -> {
226+
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
227+
this.executor.execute(() -> {
225228
try {
226229
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
227230
serverSocket.set(server);
@@ -260,12 +263,12 @@ public void testGoodNetTimeout() throws Exception {
260263
Future<Integer>[] results = (Future<Integer>[]) new Future<?>[2];
261264
for (int i = 0; i < 2; i++) {
262265
final int j = i;
263-
results[j] = (Executors.newSingleThreadExecutor().submit(() -> {
266+
results[j] = (this.executor.submit(() -> {
264267
gateway.handleMessage(MessageBuilder.withPayload("Test" + j).build());
265268
return 0;
266269
}));
267270
}
268-
Set<String> replies = new HashSet<String>();
271+
Set<String> replies = new HashSet<>();
269272
int timeouts = 0;
270273
for (int i = 0; i < 2; i++) {
271274
try {
@@ -344,7 +347,7 @@ private void testGoodNetGWTimeoutGuts(final int port, AbstractClientConnectionFa
344347
final AtomicReference<String> lastReceived = new AtomicReference<String>();
345348
final CountDownLatch serverLatch = new CountDownLatch(2);
346349

347-
Executors.newSingleThreadExecutor().execute(() -> {
350+
this.executor.execute(() -> {
348351
try {
349352
latch.countDown();
350353
int i = 0;
@@ -398,7 +401,7 @@ private void testGoodNetGWTimeoutGuts(final int port, AbstractClientConnectionFa
398401

399402
for (int i = 0; i < 2; i++) {
400403
final int j = i;
401-
results[j] = (Executors.newSingleThreadExecutor().submit(() -> {
404+
results[j] = (this.executor.submit(() -> {
402405
gateway.handleMessage(MessageBuilder.withPayload("Test" + j).build());
403406
return j;
404407
}));
@@ -442,7 +445,7 @@ public void testCachingFailover() throws Exception {
442445
final AtomicBoolean done = new AtomicBoolean();
443446
final CountDownLatch serverLatch = new CountDownLatch(1);
444447

445-
Executors.newSingleThreadExecutor().execute(() -> {
448+
this.executor.execute(() -> {
446449
try {
447450
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
448451
serverSocket.set(server);
@@ -517,12 +520,12 @@ public void testCachingFailover() throws Exception {
517520

518521
@Test
519522
public void testFailoverCached() throws Exception {
520-
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<ServerSocket>();
523+
final AtomicReference<ServerSocket> serverSocket = new AtomicReference<>();
521524
final CountDownLatch latch = new CountDownLatch(1);
522525
final AtomicBoolean done = new AtomicBoolean();
523526
final CountDownLatch serverLatch = new CountDownLatch(1);
524527

525-
Executors.newSingleThreadExecutor().execute(() -> {
528+
this.executor.execute(() -> {
526529
try {
527530
ServerSocket server = ServerSocketFactory.getDefault().createServerSocket(0);
528531
serverSocket.set(server);
@@ -667,11 +670,11 @@ private void testGWPropagatesSocketCloseGuts(final int port, AbstractClientConne
667670
final ServerSocket server) throws Exception {
668671
final CountDownLatch latch = new CountDownLatch(1);
669672
final AtomicBoolean done = new AtomicBoolean();
670-
final AtomicReference<String> lastReceived = new AtomicReference<String>();
673+
final AtomicReference<String> lastReceived = new AtomicReference<>();
671674
final CountDownLatch serverLatch = new CountDownLatch(1);
672675

673-
Executors.newSingleThreadExecutor().execute(() -> {
674-
List<Socket> sockets = new ArrayList<Socket>();
676+
this.executor.execute(() -> {
677+
List<Socket> sockets = new ArrayList<>();
675678
try {
676679
latch.countDown();
677680
while (!done.get()) {
@@ -793,8 +796,8 @@ private void testGWPropagatesSocketTimeoutGuts(final int port, AbstractClientCon
793796
final CountDownLatch latch = new CountDownLatch(1);
794797
final AtomicBoolean done = new AtomicBoolean();
795798

796-
Executors.newSingleThreadExecutor().execute(() -> {
797-
List<Socket> sockets = new ArrayList<Socket>();
799+
this.executor.execute(() -> {
800+
List<Socket> sockets = new ArrayList<>();
798801
try {
799802
latch.countDown();
800803
while (!done.get()) {

0 commit comments

Comments
 (0)