diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java index 2b36664fbf5..83880336541 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurer.java @@ -33,6 +33,7 @@ import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository; import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter; import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices; +import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; @@ -296,6 +297,13 @@ public void configure(H http) { rememberMeFilter.setSecurityContextRepository(securityContextRepository); } rememberMeFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + + SessionAuthenticationStrategy sessionAuthenticationStrategy = http + .getSharedObject(SessionAuthenticationStrategy.class); + if (sessionAuthenticationStrategy != null) { + rememberMeFilter.setSessionAuthenticationStrategy(sessionAuthenticationStrategy); + } + rememberMeFilter = postProcess(rememberMeFilter); http.addFilter(rememberMeFilter); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java index 9c39f2b01c5..e3cb83f76fd 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/RememberMeConfigurerTests.java @@ -60,6 +60,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; @@ -74,6 +75,7 @@ import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.cookie; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; @@ -334,6 +336,27 @@ public void getWhenCustomSecurityContextRepositoryThenUses() throws Exception { verify(repository).saveContext(any(), any(), any()); } + @Test + public void rememberMeExpiresSessionWhenSessionManagementMaximumSessionsExceeds() throws Exception { + this.spring.register(RememberMeMaximumSessionsConfig.class).autowire(); + + MockHttpServletRequestBuilder loginRequest = post("/login").with(csrf()) + .param("username", "user") + .param("password", "password") + .param("remember-me", "true"); + MvcResult mvcResult = this.mvc.perform(loginRequest).andReturn(); + Cookie rememberMeCookie = mvcResult.getResponse().getCookie("remember-me"); + HttpSession session = mvcResult.getRequest().getSession(); + + MockHttpServletRequestBuilder exceedsMaximumSessionsRequest = get("/abc").cookie(rememberMeCookie); + this.mvc.perform(exceedsMaximumSessionsRequest); + + MockHttpServletRequestBuilder sessionExpiredRequest = get("/abc").cookie(rememberMeCookie) + .session((MockHttpSession) session); + this.mvc.perform(sessionExpiredRequest) + .andExpect(content().string(startsWith("This session has been expired"))); + } + @Configuration @EnableWebSecurity static class NullUserDetailsConfig { @@ -617,6 +640,35 @@ SecurityFilterChain filterChain(HttpSecurity http) throws Exception { } + @Configuration + @EnableWebSecurity + static class RememberMeMaximumSessionsConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeRequests((authorizeRequests) -> + authorizeRequests + .anyRequest().hasRole("USER") + ) + .sessionManagement((sessionManagement) -> + sessionManagement + .maximumSessions(1) + ) + .formLogin(withDefaults()) + .rememberMe(withDefaults()); + return http.build(); + // @formatter:on + } + + @Bean + UserDetailsService userDetailsService() { + return new InMemoryUserDetailsManager(PasswordEncodedUser.user()); + } + + } + @Configuration @EnableWebSecurity static class SecurityContextRepositoryConfig { diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java index 164c1fc8c64..f111d20156d 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java @@ -37,6 +37,8 @@ import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.RememberMeServices; +import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy; +import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.util.Assert; @@ -81,6 +83,8 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements private SecurityContextRepository securityContextRepository = new HttpSessionSecurityContextRepository(); + private SessionAuthenticationStrategy sessionStrategy = new NullAuthenticatedSessionStrategy(); + public RememberMeAuthenticationFilter(AuthenticationManager authenticationManager, RememberMeServices rememberMeServices) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); @@ -115,6 +119,7 @@ private void doFilter(HttpServletRequest request, HttpServletResponse response, // Attempt authentication via AuthenticationManager try { rememberMeAuth = this.authenticationManager.authenticate(rememberMeAuth); + this.sessionStrategy.onAuthentication(rememberMeAuth, request, response); // Store to SecurityContextHolder SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); context.setAuthentication(rememberMeAuth); @@ -211,4 +216,17 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur this.securityContextHolderStrategy = securityContextHolderStrategy; } + /** + * The session handling strategy which will be invoked immediately after an + * authentication request is successfully processed by the + * AuthenticationManager. Used, for example, to handle changing of the + * session identifier to prevent session fixation attacks. + * @param sessionStrategy the implementation to use. If not set a null implementation + * is used. + */ + public void setSessionAuthenticationStrategy(SessionAuthenticationStrategy sessionStrategy) { + Assert.notNull(sessionStrategy, "sessionStrategy cannot be null"); + this.sessionStrategy = sessionStrategy; + } + } diff --git a/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java index 0612e48a9ab..7652477054c 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java @@ -35,6 +35,7 @@ import org.springframework.security.web.authentication.NullRememberMeServices; import org.springframework.security.web.authentication.RememberMeServices; import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler; +import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.context.SecurityContextRepository; import static org.assertj.core.api.Assertions.assertThat; @@ -170,6 +171,23 @@ public void securityContextRepositoryInvokedIfSet() throws Exception { verify(securityContextRepository).saveContext(any(), eq(request), eq(response)); } + @Test + public void sessionAuthenticationStrategyInvokedIfSet() throws Exception { + SessionAuthenticationStrategy sessionAuthenticationStrategy = mock(SessionAuthenticationStrategy.class); + AuthenticationManager am = mock(AuthenticationManager.class); + given(am.authenticate(this.remembered)).willReturn(this.remembered); + RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(am, + new MockRememberMeServices(this.remembered)); + filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target")); + filter.setSessionAuthenticationStrategy(sessionAuthenticationStrategy); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain fc = mock(FilterChain.class); + request.setRequestURI("x"); + filter.doFilter(request, response, fc); + verify(sessionAuthenticationStrategy).onAuthentication(any(), eq(request), eq(response)); + } + private class MockRememberMeServices implements RememberMeServices { private Authentication authToReturn;