diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java index ac914911ae1..d50760b91cb 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java @@ -16,20 +16,17 @@ package org.springframework.integration.channel; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.springframework.messaging.Message; import org.springframework.util.Assert; -import reactor.core.publisher.ConnectableFlux; +import reactor.core.publisher.EmitterProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; +import reactor.core.publisher.ReplayProcessor; +import reactor.core.scheduler.Schedulers; /** * The {@link AbstractMessageChannel} implementation for the @@ -37,30 +34,27 @@ * * @author Artem Bilan * @author Gary Russell + * @author Sergei Egorov * * @since 5.0 */ public class FluxMessageChannel extends AbstractMessageChannel implements Publisher>, ReactiveStreamsSubscribableChannel { - private final List>> subscribers = new ArrayList<>(); - - private final Map>, ConnectableFlux> publishers = new ConcurrentHashMap<>(); + private final EmitterProcessor> processor; - private final Flux> flux; + private final FluxSink> sink; - private FluxSink> sink; + private final ReplayProcessor subscribedSignal = ReplayProcessor.create(1); public FluxMessageChannel() { - this.flux = - Flux.>create(emitter -> this.sink = emitter, FluxSink.OverflowStrategy.IGNORE) - .publish() - .autoConnect(); + this.processor = EmitterProcessor.create(1, false); + this.sink = this.processor.sink(FluxSink.OverflowStrategy.BUFFER); } @Override protected boolean doSend(Message message, long timeout) { - Assert.state(this.subscribers.size() > 0, + Assert.state(this.processor.hasDownstreams(), () -> "The [" + this + "] doesn't have subscribers to accept messages"); this.sink.next(message); return true; @@ -68,30 +62,33 @@ protected boolean doSend(Message message, long timeout) { @Override public void subscribe(Subscriber> subscriber) { - this.subscribers.add(subscriber); - - this.flux.doOnCancel(() -> this.subscribers.remove(subscriber)) - .retry() + this.processor + .doFinally((s) -> this.subscribedSignal.onNext(this.processor.hasDownstreams())) .subscribe(subscriber); - - this.publishers.values().forEach(ConnectableFlux::connect); + this.subscribedSignal.onNext(this.processor.hasDownstreams()); } @Override public void subscribeTo(Publisher> publisher) { - ConnectableFlux connectableFlux = - Flux.from(publisher) - .handle((message, sink) -> sink.next(send(message))) - .onErrorContinue((throwable, event) -> - logger.warn("Error during processing event: " + event, throwable)) - .doOnComplete(() -> this.publishers.remove(publisher)) - .publish(); - - this.publishers.put(publisher, connectableFlux); + Flux.from(publisher) + .delaySubscription(this.subscribedSignal.filter(Boolean::booleanValue).next()) + .publishOn(Schedulers.boundedElastic()) + .doOnNext((message) -> { + try { + send(message); + } + catch (Exception e) { + logger.warn("Error during processing event: " + message, e); + } + }) + .subscribe(); + } - if (!this.subscribers.isEmpty()) { - connectableFlux.connect(); - } + @Override + public void destroy() { + this.subscribedSignal.onNext(false); + this.processor.onComplete(); + super.destroy(); } } diff --git a/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java b/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java index 6a254d0c736..c68f3dce374 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java @@ -20,17 +20,17 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.junit.Test; -import org.junit.runner.RunWith; +import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.integration.annotation.BridgeFrom; import org.springframework.integration.annotation.ServiceActivator; import org.springframework.integration.channel.FluxMessageChannel; import org.springframework.integration.channel.MessageChannelReactiveUtils; @@ -47,8 +47,10 @@ import org.springframework.messaging.support.GenericMessage; import org.springframework.messaging.support.MessageBuilder; import org.springframework.test.annotation.DirtiesContext; -import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; +import reactor.core.Disposable; +import reactor.core.publisher.EmitterProcessor; import reactor.core.publisher.Flux; /** @@ -56,7 +58,7 @@ * * @since 5.0 */ -@RunWith(SpringRunner.class) +@SpringJUnitConfig @DirtiesContext public class FluxMessageChannelTests { @@ -64,7 +66,7 @@ public class FluxMessageChannelTests { private MessageChannel fluxMessageChannel; @Autowired - private MessageChannel queueChannel; + private QueueChannel queueChannel; @Autowired private PollableChannel errorChannel; @@ -73,7 +75,7 @@ public class FluxMessageChannelTests { private IntegrationFlowContext integrationFlowContext; @Test - public void testFluxMessageChannel() { + void testFluxMessageChannel() { QueueChannel replyChannel = new QueueChannel(); for (int i = 0; i < 10; i++) { @@ -90,28 +92,35 @@ public void testFluxMessageChannel() { Message error = this.errorChannel.receive(0); assertThat(error).isNotNull(); assertThat(((MessagingException) error.getPayload()).getFailedMessage().getPayload()).isEqualTo(5); + + List> messages = this.queueChannel.clear(); + assertThat(messages).extracting((message) -> (Integer) message.getPayload()) + .containsAll(IntStream.range(0, 10).boxed().collect(Collectors.toList())); } @Test - public void testMessageChannelReactiveAdaptation() throws InterruptedException { + void testMessageChannelReactiveAdaptation() throws InterruptedException { CountDownLatch done = new CountDownLatch(2); List results = new ArrayList<>(); - Flux.from(MessageChannelReactiveUtils.toPublisher(this.queueChannel)) - .map(Message::getPayload) - .map(String::toUpperCase) - .doOnNext(results::add) - .subscribe(v -> done.countDown()); + Disposable disposable = + Flux.from(MessageChannelReactiveUtils.toPublisher(this.queueChannel)) + .map(Message::getPayload) + .map(String::toUpperCase) + .doOnNext(results::add) + .subscribe(v -> done.countDown()); this.queueChannel.send(new GenericMessage<>("foo")); this.queueChannel.send(new GenericMessage<>("bar")); assertThat(done.await(10, TimeUnit.SECONDS)).isTrue(); assertThat(results).containsExactly("FOO", "BAR"); + + disposable.dispose(); } @Test - public void testFluxMessageChannelCleanUp() throws InterruptedException { + void testFluxMessageChannelCleanUp() throws InterruptedException { FluxMessageChannel flux = MessageChannels.flux().get(); CountDownLatch finishLatch = new CountDownLatch(1); @@ -130,9 +139,9 @@ public void testFluxMessageChannelCleanUp() throws InterruptedException { assertThat(finishLatch.await(10, TimeUnit.SECONDS)).isTrue(); - assertThat(TestUtils.getPropertyValue(flux, "publishers", Map.class).isEmpty()).isTrue(); - flowRegistration.destroy(); + + assertThat(TestUtils.getPropertyValue(flux, "processor", EmitterProcessor.class).isTerminated()).isTrue(); } @Configuration @@ -158,6 +167,7 @@ public String handle(int payload) { } @Bean + @BridgeFrom("fluxMessageChannel") public MessageChannel queueChannel() { return new QueueChannel(); }