Skip to content

Apply Nullability to spring-integration-websocket module #10195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.jspecify.annotations.Nullable;

import org.springframework.context.Lifecycle;
import org.springframework.context.SmartLifecycle;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -60,11 +62,12 @@ public final class ClientWebSocketContainer extends IntegrationWebSocketContaine

private int connectionTimeout = DEFAULT_CONNECTION_TIMEOUT;

@SuppressWarnings("NullAway.Init")
private volatile CountDownLatch connectionLatch;

private volatile WebSocketSession clientSession;
private volatile @Nullable WebSocketSession clientSession;

private volatile Throwable openConnectionException;
private volatile @Nullable Throwable openConnectionException;

private volatile boolean connecting;

Expand Down Expand Up @@ -121,7 +124,7 @@ public void setConnectionTimeout(int connectionTimeout) {
* @return the {@link #clientSession}, if established.
*/
@Override
public WebSocketSession getSession(String sessionId) {
public WebSocketSession getSession(@Nullable String sessionId) {
if (isRunning()) {
if (!isConnected() && !this.connecting) {
stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;

import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.socket.CloseStatus;
Expand Down Expand Up @@ -80,14 +80,13 @@ public abstract class IntegrationWebSocketContainer implements DisposableBean {

private final List<String> supportedProtocols = new ArrayList<>();

private WebSocketListener messageListener;
private @Nullable WebSocketListener messageListener;

private int sendTimeLimit = DEFAULT_SEND_TIME_LIMIT;

private int sendBufferSizeLimit = DEFAULT_SEND_BUFFER_SIZE;

@Nullable
private ConcurrentWebSocketSessionDecorator.OverflowStrategy sendBufferOverflowStrategy;
private ConcurrentWebSocketSessionDecorator.@Nullable OverflowStrategy sendBufferOverflowStrategy;

public void setSendTimeLimit(int sendTimeLimit) {
this.sendTimeLimit = sendTimeLimit;
Expand All @@ -107,7 +106,7 @@ public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
* @see ConcurrentWebSocketSessionDecorator
*/
public void setSendBufferOverflowStrategy(
@Nullable ConcurrentWebSocketSessionDecorator.OverflowStrategy overflowStrategy) {
ConcurrentWebSocketSessionDecorator.@Nullable OverflowStrategy overflowStrategy) {

this.sendBufferOverflowStrategy = overflowStrategy;
}
Expand Down Expand Up @@ -155,7 +154,7 @@ public Map<String, WebSocketSession> getSessions() {
return Collections.unmodifiableMap(this.sessions);
}

public WebSocketSession getSession(String sessionId) {
public WebSocketSession getSession(@Nullable String sessionId) {
WebSocketSession session = this.sessions.get(sessionId);
Assert.notNull(session, () -> "Session not found for id '" + sessionId + "'");
return session;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.Arrays;

import org.jspecify.annotations.Nullable;

import org.springframework.context.Lifecycle;
import org.springframework.context.SmartLifecycle;
import org.springframework.integration.JavaUtils;
Expand Down Expand Up @@ -48,6 +50,7 @@
* @author Artem Bilan
* @author Gary Russell
* @author Christian Tzolov
* @author Jooyoung Pyoung
*
* @since 4.1
*/
Expand All @@ -56,21 +59,21 @@ public class ServerWebSocketContainer extends IntegrationWebSocketContainer

private final String[] paths;

private HandshakeHandler handshakeHandler;
private @Nullable HandshakeHandler handshakeHandler;

private HandshakeInterceptor[] interceptors;
private HandshakeInterceptor[] interceptors = {};

private WebSocketHandlerDecoratorFactory[] decoratorFactories;
private WebSocketHandlerDecoratorFactory @Nullable [] decoratorFactories;

private SockJsServiceOptions sockJsServiceOptions;
private @Nullable SockJsServiceOptions sockJsServiceOptions;

private String[] origins;
private String[] origins = {};

private boolean autoStartup = true;

private int phase = 0;

private TaskScheduler sockJsTaskScheduler;
private @Nullable TaskScheduler sockJsTaskScheduler;

public ServerWebSocketContainer(String... paths) {
Assert.notEmpty(paths, "'paths' must not be empty");
Expand Down Expand Up @@ -144,7 +147,7 @@ public void setSockJsTaskScheduler(TaskScheduler sockJsTaskScheduler) {
this.sockJsTaskScheduler = sockJsTaskScheduler;
}

public TaskScheduler getSockJsTaskScheduler() {
public @Nullable TaskScheduler getSockJsTaskScheduler() {
return this.sockJsTaskScheduler;
}

Expand All @@ -159,19 +162,25 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
}
}

Assert.notNull(this.handshakeHandler, "'handshakeHandler' must not be null");
WebSocketHandlerRegistration registration = registry.addHandler(webSocketHandler, this.paths)
.setHandshakeHandler(this.handshakeHandler)
.addInterceptors(this.interceptors)
.setAllowedOrigins(this.origins);
.setHandshakeHandler(this.handshakeHandler);

configureRegistration(registration);
configureSockJsOptionsIfAny(registration);
}

private void configureRegistration(WebSocketHandlerRegistration registration) {
registration.addInterceptors(this.interceptors);
registration.setAllowedOrigins(this.origins);
}

private void configureSockJsOptionsIfAny(WebSocketHandlerRegistration registration) {
if (this.sockJsServiceOptions != null) {
SockJsServiceRegistration sockJsServiceRegistration = registration.withSockJS();
JavaUtils.INSTANCE
.acceptIfCondition(this.sockJsServiceOptions.taskScheduler == null,
.acceptIfCondition(this.sockJsServiceOptions.taskScheduler == null &&
this.sockJsTaskScheduler != null,
this.sockJsTaskScheduler, this.sockJsServiceOptions::setTaskScheduler)
.acceptIfNotNull(this.sockJsServiceOptions.webSocketEnabled,
sockJsServiceRegistration::setWebSocketEnabled)
Expand Down Expand Up @@ -227,8 +236,8 @@ public boolean isRunning() {
public void start() {
this.lock.lock();
try {
if (this.handshakeHandler instanceof Lifecycle && !isRunning()) {
((Lifecycle) this.handshakeHandler).start();
if (this.handshakeHandler instanceof Lifecycle lifeCycleHandler && !isRunning()) {
lifeCycleHandler.start();
}
}
finally {
Expand All @@ -238,15 +247,15 @@ public void start() {

@Override
public void stop() {
if (isRunning()) {
((Lifecycle) this.handshakeHandler).stop();
if (this.handshakeHandler instanceof Lifecycle lifeCycleHandler && isRunning()) {
lifeCycleHandler.stop();
}
}

@Override
public void stop(Runnable callback) {
if (isRunning()) {
((Lifecycle) this.handshakeHandler).stop();
if (this.handshakeHandler instanceof Lifecycle lifeCycleHandler && isRunning()) {
lifeCycleHandler.stop();
}
callback.run();
}
Expand All @@ -256,27 +265,27 @@ public void stop(Runnable callback) {
*/
public static class SockJsServiceOptions {

private TaskScheduler taskScheduler;
private @Nullable TaskScheduler taskScheduler;

private String clientLibraryUrl;
private @Nullable String clientLibraryUrl;

private Integer streamBytesLimit;
private @Nullable Integer streamBytesLimit;

private Boolean sessionCookieNeeded;
private @Nullable Boolean sessionCookieNeeded;

private Long heartbeatTime;
private @Nullable Long heartbeatTime;

private Long disconnectDelay;
private @Nullable Long disconnectDelay;

private Integer httpMessageCacheSize;
private @Nullable Integer httpMessageCacheSize;

private Boolean webSocketEnabled;
private @Nullable Boolean webSocketEnabled;

private TransportHandler[] transportHandlers;
private TransportHandler @Nullable [] transportHandlers;

private SockJsMessageCodec messageCodec;
private @Nullable SockJsMessageCodec messageCodec;

private Boolean suppressCors;
private @Nullable Boolean suppressCors;

public SockJsServiceOptions setTaskScheduler(TaskScheduler taskScheduler) {
this.taskScheduler = taskScheduler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ class IntegrationServletWebSocketHandlerRegistry extends ServletWebSocketHandler

private final Map<WebSocketHandler, List<String>> dynamicRegistrations = new HashMap<>();

@SuppressWarnings("NullAway.Init")
private ApplicationContext applicationContext;

@SuppressWarnings("NullAway.Init")
private TaskScheduler sockJsTaskScheduler;

@SuppressWarnings("NullAway.Init")
private volatile IntegrationDynamicWebSocketHandlerMapping dynamicHandlerMapping;

IntegrationServletWebSocketHandlerRegistry() {
Expand Down Expand Up @@ -124,6 +127,7 @@ void removeRegistration(ServerWebSocketContainer serverWebSocketContainer) {

private static final class DynamicHandlerRegistrationProxy implements WebSocketHandlerRegistry {

@SuppressWarnings("NullAway.Init")
private IntegrationDynamicWebSocketHandlerRegistration registration;

DynamicHandlerRegistrationProxy() {
Expand All @@ -141,6 +145,7 @@ public WebSocketHandlerRegistration addHandler(WebSocketHandler webSocketHandler
private static final class IntegrationDynamicWebSocketHandlerRegistration
extends ServletWebSocketHandlerRegistration {

@SuppressWarnings("NullAway.Init")
private WebSocketHandler handler;

IntegrationDynamicWebSocketHandlerRegistration() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ private void registerEnableWebSocketIfNecessary(BeanDefinitionRegistry registry)

static class WebSocketHandlerMappingFactoryBean extends AbstractFactoryBean<HandlerMapping> {

@SuppressWarnings("NullAway.Init")
private IntegrationServletWebSocketHandlerRegistry registry;

@SuppressWarnings("NullAway.Init")
private ThreadPoolTaskScheduler sockJsTaskScheduler;

public void setRegistry(IntegrationServletWebSocketHandlerRegistry registry) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/**
* Contains parser classes for the WebSockets namespace support.
*/

@org.jspecify.annotations.NullMarked
package org.springframework.integration.websocket.config;
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/**
* Provides classes which represent WebSocket event components.
*/

@org.jspecify.annotations.NullMarked
package org.springframework.integration.websocket.event;
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import org.jspecify.annotations.Nullable;

import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
Expand Down Expand Up @@ -67,6 +68,7 @@
*
* @author Artem Bilan
* @author Ngoc Nhan
* @author Jooyoung Pyoung
*
* @since 4.1
*/
Expand Down Expand Up @@ -99,17 +101,18 @@ public class WebSocketInboundChannelAdapter extends MessageProducerSupport

private final MessageChannel subProtocolHandlerChannel;

private final AtomicReference<Class<?>> payloadType = new AtomicReference<>(String.class);
private Class<?> payloadType = String.class;

@SuppressWarnings("NullAway.Init")
private ApplicationEventPublisher eventPublisher;

private List<MessageConverter> messageConverters;
private @Nullable List<MessageConverter> messageConverters;

private boolean mergeWithDefaultConverters = false;

private boolean useBroker;

private AbstractBrokerMessageHandler brokerHandler;
private @Nullable AbstractBrokerMessageHandler brokerHandler;

public WebSocketInboundChannelAdapter(IntegrationWebSocketContainer webSocketContainer) {
this(webSocketContainer, new SubProtocolHandlerRegistry(new PassThruSubProtocolHandler()));
Expand Down Expand Up @@ -163,7 +166,7 @@ public void setMergeWithDefaultConverters(boolean mergeWithDefaultConverters) {
*/
public void setPayloadType(Class<?> payloadType) {
Assert.notNull(payloadType, "'payloadType' must not be null");
this.payloadType.set(payloadType);
this.payloadType = payloadType;
}

/**
Expand Down Expand Up @@ -303,7 +306,7 @@ public boolean isActive() {
return active;
}

@SuppressWarnings("unchecked")
@SuppressWarnings({"unchecked", "NullAway"}) // Dataflow analysis limitation
private void handleMessageAndSend(final Message<?> message) {
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor.wrap(message);
StompCommand stompCommand = (StompCommand) headerAccessor.getHeader("stompCommand");
Expand Down Expand Up @@ -335,8 +338,8 @@ else if (StompCommand.RECEIPT.equals(stompCommand)) {
}
}

private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor, StompCommand stompCommand,
SimpMessageType messageType) {
private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccessor, @Nullable StompCommand stompCommand,
@Nullable SimpMessageType messageType) {

return (messageType == null // NOSONAR pretty simple logic
|| SimpMessageType.MESSAGE.equals(messageType)
Expand All @@ -346,7 +349,8 @@ private boolean isProcessingTypeOrCommand(SimpMessageHeaderAccessor headerAccess
&& !checkDestinationPrefix(headerAccessor.getDestination());
}

private boolean checkDestinationPrefix(String destination) {
@SuppressWarnings("NullAway") // Dataflow analysis limitation
private boolean checkDestinationPrefix(@Nullable String destination) {
if (this.useBroker) {
Collection<String> destinationPrefixes = this.brokerHandler.getDestinationPrefixes();
if ((destination == null) || CollectionUtils.isEmpty(destinationPrefixes)) {
Expand Down Expand Up @@ -374,11 +378,11 @@ private void produceConnectAckMessage(Message<?> message, SimpMessageHeaderAcces
}

private void produceMessage(Message<?> message, SimpMessageHeaderAccessor headerAccessor) {
Object payload = this.messageConverter.fromMessage(message, this.payloadType.get());
Object payload = this.messageConverter.fromMessage(message, this.payloadType);
Assert.state(payload != null,
() -> "The message converter '" + this.messageConverter +
"' produced no payload for message '" + message +
"' and expected payload type: " + this.payloadType.get());
"' and expected payload type: " + this.payloadType);
Message<Object> messageToSend =
getMessageBuilderFactory()
.withPayload(payload)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/**
* Provides classes which represent inbound WebSocket components.
*/

@org.jspecify.annotations.NullMarked
package org.springframework.integration.websocket.inbound;
Loading